From e3afa875b9939ac27ad74b32244ecfa9005e9e1b Mon Sep 17 00:00:00 2001 From: Mateusz Bilski Date: Wed, 16 Aug 2023 13:11:31 +0200 Subject: [PATCH] Sync fork (#8) * Sync * Sync --- google/appengine_gen1.go | 1 - google/appengine_gen2_flex.go | 1 - google/default.go | 37 +- google/doc.go | 65 ++- google/google.go | 15 +- google/internal/externalaccount/aws.go | 105 +++- google/internal/externalaccount/aws_test.go | 531 ++++++++++++++++-- .../externalaccount/basecredentials.go | 36 +- .../externalaccount/basecredentials_test.go | 135 ----- .../externalaccount/impersonate_test.go | 4 +- oauth2.go | 31 + token.go | 10 + 12 files changed, 701 insertions(+), 270 deletions(-) diff --git a/google/appengine_gen1.go b/google/appengine_gen1.go index 16c6c6b90..e61587945 100644 --- a/google/appengine_gen1.go +++ b/google/appengine_gen1.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build appengine -// +build appengine // This file applies to App Engine first generation runtimes (<= Go 1.9). diff --git a/google/appengine_gen2_flex.go b/google/appengine_gen2_flex.go index a7e27b3d2..9c79aa0a0 100644 --- a/google/appengine_gen2_flex.go +++ b/google/appengine_gen2_flex.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !appengine -// +build !appengine // This file applies to App Engine second generation runtimes (>= Go 1.11) and App Engine flexible. diff --git a/google/default.go b/google/default.go index 7ed02cd41..2cf71f0f9 100644 --- a/google/default.go +++ b/google/default.go @@ -8,17 +8,19 @@ import ( "context" "encoding/json" "fmt" - "io/ioutil" "net/http" "os" "path/filepath" "runtime" + "time" "cloud.google.com/go/compute/metadata" "golang.org/x/oauth2" "golang.org/x/oauth2/authhandler" ) +const adcSetupURL = "https://cloud.google.com/docs/authentication/external/set-up-adc" + // Credentials holds Google credentials, including "Application Default Credentials". // For more details, see: // https://developers.google.com/accounts/docs/application-default-credentials @@ -62,6 +64,18 @@ type CredentialsParams struct { // PKCE is used to support PKCE flow. Optional for 3LO flow. PKCE *authhandler.PKCEParams + + // The OAuth2 TokenURL default override. This value overrides the default TokenURL, + // unless explicitly specified by the credentials config file. Optional. + TokenURL string + + // EarlyTokenRefresh is the amount of time before a token expires that a new + // token will be preemptively fetched. If unset the default value is 10 + // seconds. + // + // Note: This option is currently only respected when using credentials + // fetched from the GCE metadata server. + EarlyTokenRefresh time.Duration } func (params CredentialsParams) deepCopy() CredentialsParams { @@ -127,17 +141,15 @@ func FindDefaultCredentialsWithParams(ctx context.Context, params CredentialsPar // Second, try a well-known file. filename := wellKnownFile() - if creds, err := readCredentialsFile(ctx, filename, params); err == nil { - return creds, nil - } else if !os.IsNotExist(err) { - return nil, fmt.Errorf("google: error getting credentials using well-known file (%v): %v", filename, err) + if b, err := os.ReadFile(filename); err == nil { + return CredentialsFromJSONWithParams(ctx, b, params) } // Third, if we're on a Google App Engine standard first generation runtime (<= Go 1.9) // use those credentials. App Engine standard second generation runtimes (>= Go 1.11) // and App Engine flexible use ComputeTokenSource and the metadata server. if appengineTokenFunc != nil { - return &DefaultCredentials{ + return &Credentials{ ProjectID: appengineAppIDFunc(ctx), TokenSource: AppEngineTokenSource(ctx, params.Scopes...), }, nil @@ -147,15 +159,14 @@ func FindDefaultCredentialsWithParams(ctx context.Context, params CredentialsPar // or App Engine flexible, use the metadata server. if metadata.OnGCE() { id, _ := metadata.ProjectID() - return &DefaultCredentials{ + return &Credentials{ ProjectID: id, - TokenSource: ComputeTokenSource("", params.Scopes...), + TokenSource: computeTokenSource("", params.EarlyTokenRefresh, params.Scopes...), }, nil } // None are found; return helpful error. - const url = "https://developers.google.com/accounts/docs/application-default-credentials" - return nil, fmt.Errorf("google: could not find default credentials. See %v for more information.", url) + return nil, fmt.Errorf("google: could not find default credentials. See %v for more information", adcSetupURL) } // FindDefaultCredentials invokes FindDefaultCredentialsWithParams with the specified scopes. @@ -194,7 +205,7 @@ func CredentialsFromJSONWithParams(ctx context.Context, jsonData []byte, params return nil, err } ts = newErrWrappingTokenSource(ts) - return &DefaultCredentials{ + return &Credentials{ ProjectID: f.ProjectID, TokenSource: ts, JSON: jsonData, @@ -216,8 +227,8 @@ func wellKnownFile() string { return filepath.Join(guessUnixHomeDir(), ".config", "gcloud", f) } -func readCredentialsFile(ctx context.Context, filename string, params CredentialsParams) (*DefaultCredentials, error) { - b, err := ioutil.ReadFile(filename) +func readCredentialsFile(ctx context.Context, filename string, params CredentialsParams) (*Credentials, error) { + b, err := os.ReadFile(filename) if err != nil { return nil, err } diff --git a/google/doc.go b/google/doc.go index b3e7bc85c..ca717634a 100644 --- a/google/doc.go +++ b/google/doc.go @@ -26,7 +26,7 @@ // // Using workload identity federation, your application can access Google Cloud // resources from Amazon Web Services (AWS), Microsoft Azure or any identity -// provider that supports OpenID Connect (OIDC). +// provider that supports OpenID Connect (OIDC) or SAML 2.0. // Traditionally, applications running outside Google Cloud have used service // account keys to access Google Cloud resources. Using identity federation, // you can allow your workload to impersonate a service account. @@ -36,26 +36,75 @@ // Follow the detailed instructions on how to configure Workload Identity Federation // in various platforms: // -// Amazon Web Services (AWS): https://cloud.google.com/iam/docs/access-resources-aws -// Microsoft Azure: https://cloud.google.com/iam/docs/access-resources-azure -// OIDC identity provider: https://cloud.google.com/iam/docs/access-resources-oidc +// Amazon Web Services (AWS): https://cloud.google.com/iam/docs/workload-identity-federation-with-other-clouds#aws +// Microsoft Azure: https://cloud.google.com/iam/docs/workload-identity-federation-with-other-clouds#azure +// OIDC identity provider: https://cloud.google.com/iam/docs/workload-identity-federation-with-other-providers#oidc +// SAML 2.0 identity provider: https://cloud.google.com/iam/docs/workload-identity-federation-with-other-providers#saml // // For OIDC and SAML providers, the library can retrieve tokens in three ways: // from a local file location (file-sourced credentials), from a server // (URL-sourced credentials), or from a local executable (executable-sourced // credentials). // For file-sourced credentials, a background process needs to be continuously -// refreshing the file location with a new OIDC token prior to expiration. +// refreshing the file location with a new OIDC/SAML token prior to expiration. // For tokens with one hour lifetimes, the token needs to be updated in the file // every hour. The token can be stored directly as plain text or in JSON format. // For URL-sourced credentials, a local server needs to host a GET endpoint to -// return the OIDC token. The response can be in plain text or JSON. +// return the OIDC/SAML token. The response can be in plain text or JSON. // Additional required request headers can also be specified. // For executable-sourced credentials, an application needs to be available to -// output the OIDC token and other information in a JSON format. +// output the OIDC/SAML token and other information in a JSON format. // For more information on how these work (and how to implement // executable-sourced credentials), please check out: -// https://cloud.google.com/iam/docs/using-workload-identity-federation#oidc +// https://cloud.google.com/iam/docs/workload-identity-federation-with-other-providers#create_a_credential_configuration +// +// Note that this library does not perform any validation on the token_url, token_info_url, +// or service_account_impersonation_url fields of the credential configuration. +// It is not recommended to use a credential configuration that you did not generate with +// the gcloud CLI unless you verify that the URL fields point to a googleapis.com domain. +// +// # Workforce Identity Federation +// +// Workforce identity federation lets you use an external identity provider (IdP) to +// authenticate and authorize a workforce—a group of users, such as employees, partners, +// and contractors—using IAM, so that the users can access Google Cloud services. +// Workforce identity federation extends Google Cloud's identity capabilities to support +// syncless, attribute-based single sign on. +// +// With workforce identity federation, your workforce can access Google Cloud resources +// using an external identity provider (IdP) that supports OpenID Connect (OIDC) or +// SAML 2.0 such as Azure Active Directory (Azure AD), Active Directory Federation +// Services (AD FS), Okta, and others. +// +// Follow the detailed instructions on how to configure Workload Identity Federation +// in various platforms: +// +// Azure AD: https://cloud.google.com/iam/docs/workforce-sign-in-azure-ad +// Okta: https://cloud.google.com/iam/docs/workforce-sign-in-okta +// OIDC identity provider: https://cloud.google.com/iam/docs/configuring-workforce-identity-federation#oidc +// SAML 2.0 identity provider: https://cloud.google.com/iam/docs/configuring-workforce-identity-federation#saml +// +// For workforce identity federation, the library can retrieve tokens in three ways: +// from a local file location (file-sourced credentials), from a server +// (URL-sourced credentials), or from a local executable (executable-sourced +// credentials). +// For file-sourced credentials, a background process needs to be continuously +// refreshing the file location with a new OIDC/SAML token prior to expiration. +// For tokens with one hour lifetimes, the token needs to be updated in the file +// every hour. The token can be stored directly as plain text or in JSON format. +// For URL-sourced credentials, a local server needs to host a GET endpoint to +// return the OIDC/SAML token. The response can be in plain text or JSON. +// Additional required request headers can also be specified. +// For executable-sourced credentials, an application needs to be available to +// output the OIDC/SAML token and other information in a JSON format. +// For more information on how these work (and how to implement +// executable-sourced credentials), please check out: +// https://cloud.google.com/iam/docs/workforce-obtaining-short-lived-credentials#generate_a_configuration_file_for_non-interactive_sign-in +// +// Note that this library does not perform any validation on the token_url, token_info_url, +// or service_account_impersonation_url fields of the credential configuration. +// It is not recommended to use a credential configuration that you did not generate with +// the gcloud CLI unless you verify that the URL fields point to a googleapis.com domain. // // # Credentials // diff --git a/google/google.go b/google/google.go index 8df0c493e..cc1223889 100644 --- a/google/google.go +++ b/google/google.go @@ -26,6 +26,9 @@ var Endpoint = oauth2.Endpoint{ AuthStyle: oauth2.AuthStyleInParams, } +// MTLSTokenURL is Google's OAuth 2.0 default mTLS endpoint. +const MTLSTokenURL = "https://oauth2.mtls.googleapis.com/token" + // JWTTokenURL is Google's OAuth 2.0 token URL to use with the JWT flow. const JWTTokenURL = "https://oauth2.googleapis.com/token" @@ -172,7 +175,11 @@ func (f *credentialsFile) tokenSource(ctx context.Context, params CredentialsPar cfg.Endpoint.AuthURL = Endpoint.AuthURL } if cfg.Endpoint.TokenURL == "" { - cfg.Endpoint.TokenURL = Endpoint.TokenURL + if params.TokenURL != "" { + cfg.Endpoint.TokenURL = params.TokenURL + } else { + cfg.Endpoint.TokenURL = Endpoint.TokenURL + } } tok := &oauth2.Token{RefreshToken: f.RefreshToken} return cfg.TokenSource(ctx, tok), nil @@ -224,7 +231,11 @@ func (f *credentialsFile) tokenSource(ctx context.Context, params CredentialsPar // Further information about retrieving access tokens from the GCE metadata // server can be found at https://cloud.google.com/compute/docs/authentication. func ComputeTokenSource(account string, scope ...string) oauth2.TokenSource { - return oauth2.ReuseTokenSource(nil, computeSource{account: account, scopes: scope}) + return computeTokenSource(account, 0, scope...) +} + +func computeTokenSource(account string, earlyExpiry time.Duration, scope ...string) oauth2.TokenSource { + return oauth2.ReuseTokenSourceWithExpiry(nil, computeSource{account: account, scopes: scope}, earlyExpiry) } type computeSource struct { diff --git a/google/internal/externalaccount/aws.go b/google/internal/externalaccount/aws.go index e917195d5..2bf3202b2 100644 --- a/google/internal/externalaccount/aws.go +++ b/google/internal/externalaccount/aws.go @@ -62,6 +62,13 @@ const ( // The AWS authorization header name for the auto-generated date. awsDateHeader = "x-amz-date" + // Supported AWS configuration environment variables. + awsAccessKeyId = "AWS_ACCESS_KEY_ID" + awsDefaultRegion = "AWS_DEFAULT_REGION" + awsRegion = "AWS_REGION" + awsSecretAccessKey = "AWS_SECRET_ACCESS_KEY" + awsSessionToken = "AWS_SESSION_TOKEN" + awsTimeFormatLong = "20060102T150405Z" awsTimeFormatShort = "20060102" ) @@ -267,6 +274,49 @@ type awsRequest struct { Headers []awsRequestHeader `json:"headers"` } +func (cs awsCredentialSource) validateMetadataServers() error { + if err := cs.validateMetadataServer(cs.RegionURL, "region_url"); err != nil { + return err + } + if err := cs.validateMetadataServer(cs.CredVerificationURL, "url"); err != nil { + return err + } + return cs.validateMetadataServer(cs.IMDSv2SessionTokenURL, "imdsv2_session_token_url") +} + +var validHostnames []string = []string{"169.254.169.254", "fd00:ec2::254"} + +func (cs awsCredentialSource) isValidMetadataServer(metadataUrl string) bool { + if metadataUrl == "" { + // Zero value means use default, which is valid. + return true + } + + u, err := url.Parse(metadataUrl) + if err != nil { + // Unparseable URL means invalid + return false + } + + for _, validHostname := range validHostnames { + if u.Hostname() == validHostname { + // If it's one of the valid hostnames, everything is good + return true + } + } + + // hostname not found in our allowlist, so not valid + return false +} + +func (cs awsCredentialSource) validateMetadataServer(metadataUrl, urlName string) error { + if !cs.isValidMetadataServer(metadataUrl) { + return fmt.Errorf("oauth2/google: invalid hostname %s for %s", metadataUrl, urlName) + } + + return nil +} + func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, error) { if cs.client == nil { cs.client = oauth2.NewClient(cs.ctx, nil) @@ -274,16 +324,33 @@ func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, erro return cs.client.Do(req.WithContext(cs.ctx)) } +func canRetrieveRegionFromEnvironment() bool { + // The AWS region can be provided through AWS_REGION or AWS_DEFAULT_REGION. Only one is + // required. + return getenv(awsRegion) != "" || getenv(awsDefaultRegion) != "" +} + +func canRetrieveSecurityCredentialFromEnvironment() bool { + // Check if both AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are available. + return getenv(awsAccessKeyId) != "" && getenv(awsSecretAccessKey) != "" +} + +func shouldUseMetadataServer() bool { + return !canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment() +} + func (cs awsCredentialSource) subjectToken() (string, error) { if cs.requestSigner == nil { - awsSessionToken, err := cs.getAWSSessionToken() - if err != nil { - return "", err - } - headers := make(map[string]string) - if awsSessionToken != "" { - headers[awsIMDSv2SessionTokenHeader] = awsSessionToken + if shouldUseMetadataServer() { + awsSessionToken, err := cs.getAWSSessionToken() + if err != nil { + return "", err + } + + if awsSessionToken != "" { + headers[awsIMDSv2SessionTokenHeader] = awsSessionToken + } } awsSecurityCredentials, err := cs.getSecurityCredentials(headers) @@ -389,11 +456,11 @@ func (cs *awsCredentialSource) getAWSSessionToken() (string, error) { } func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, error) { - if envAwsRegion := getenv("AWS_REGION"); envAwsRegion != "" { - return envAwsRegion, nil - } - if envAwsRegion := getenv("AWS_DEFAULT_REGION"); envAwsRegion != "" { - return envAwsRegion, nil + if canRetrieveRegionFromEnvironment() { + if envAwsRegion := getenv(awsRegion); envAwsRegion != "" { + return envAwsRegion, nil + } + return getenv("AWS_DEFAULT_REGION"), nil } if cs.RegionURL == "" { @@ -434,14 +501,12 @@ func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, err } func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string) (result awsSecurityCredentials, err error) { - if accessKeyID := getenv("AWS_ACCESS_KEY_ID"); accessKeyID != "" { - if secretAccessKey := getenv("AWS_SECRET_ACCESS_KEY"); secretAccessKey != "" { - return awsSecurityCredentials{ - AccessKeyID: accessKeyID, - SecretAccessKey: secretAccessKey, - SecurityToken: getenv("AWS_SESSION_TOKEN"), - }, nil - } + if canRetrieveSecurityCredentialFromEnvironment() { + return awsSecurityCredentials{ + AccessKeyID: getenv(awsAccessKeyId), + SecretAccessKey: getenv(awsSecretAccessKey), + SecurityToken: getenv(awsSessionToken), + }, nil } roleName, err := cs.getMetadataRoleName(headers) diff --git a/google/internal/externalaccount/aws_test.go b/google/internal/externalaccount/aws_test.go index 093438925..058b00424 100644 --- a/google/internal/externalaccount/aws_test.go +++ b/google/internal/externalaccount/aws_test.go @@ -474,6 +474,38 @@ func createDefaultAwsTestServer() *testAwsServer { ) } +func createDefaultAwsTestServerWithImdsv2(t *testing.T) *testAwsServer { + validateSessionTokenHeaders := func(r *http.Request) { + if r.URL.Path == "/latest/api/token" { + headerValue := r.Header.Get(awsIMDSv2SessionTtlHeader) + if headerValue != awsIMDSv2SessionTtl { + t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTtlHeader, headerValue, awsIMDSv2SessionTtl) + } + } else { + headerValue := r.Header.Get(awsIMDSv2SessionTokenHeader) + if headerValue != "sessiontoken" { + t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTokenHeader, headerValue, "sessiontoken") + } + } + } + + return createAwsTestServer( + "/latest/meta-data/iam/security-credentials", + "/latest/meta-data/placement/availability-zone", + "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "/latest/api/token", + "gcp-aws-role", + "us-east-2b", + map[string]string{ + "SecretAccessKey": secretAccessKey, + "AccessKeyId": accessKeyID, + "Token": securityToken, + }, + "sessiontoken", + validateSessionTokenHeaders, + ) +} + func (server *testAwsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch p := r.URL.Path; p { case server.url: @@ -553,16 +585,25 @@ func getExpectedSubjectToken(url, region, accessKeyID, secretAccessKey, security func TestAWSCredential_BasicRequest(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() - getenv = setEnvironment(map[string]string{}) oldNow := now - defer func() { now = oldNow }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() + getenv = setEnvironment(map[string]string{}) now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -588,46 +629,27 @@ func TestAWSCredential_BasicRequest(t *testing.T) { } func TestAWSCredential_IMDSv2(t *testing.T) { - validateSessionTokenHeaders := func(r *http.Request) { - if r.URL.Path == "/latest/api/token" { - headerValue := r.Header.Get(awsIMDSv2SessionTtlHeader) - if headerValue != awsIMDSv2SessionTtl { - t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTtlHeader, headerValue, awsIMDSv2SessionTtl) - } - } else { - headerValue := r.Header.Get(awsIMDSv2SessionTokenHeader) - if headerValue != "sessiontoken" { - t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTokenHeader, headerValue, "sessiontoken") - } - } - } - - server := createAwsTestServer( - "/latest/meta-data/iam/security-credentials", - "/latest/meta-data/placement/availability-zone", - "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", - "/latest/api/token", - "gcp-aws-role", - "us-east-2b", - map[string]string{ - "SecretAccessKey": secretAccessKey, - "AccessKeyId": accessKeyID, - "Token": securityToken, - }, - "sessiontoken", - validateSessionTokenHeaders, - ) + server := createDefaultAwsTestServerWithImdsv2(t) ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() - getenv = setEnvironment(map[string]string{}) oldNow := now - defer func() { now = oldNow }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() + getenv = setEnvironment(map[string]string{}) now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -655,17 +677,26 @@ func TestAWSCredential_IMDSv2(t *testing.T) { func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } delete(server.Credentials, "Token") tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() - getenv = setEnvironment(map[string]string{}) oldNow := now - defer func() { now = oldNow }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() + getenv = setEnvironment(map[string]string{}) now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -693,20 +724,29 @@ func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) { func TestAWSCredential_BasicRequestWithEnv(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldNow := now + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{ "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", "AWS_REGION": "us-west-1", }) - oldNow := now - defer func() { now = oldNow }() now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -734,20 +774,29 @@ func TestAWSCredential_BasicRequestWithEnv(t *testing.T) { func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldNow := now + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{ "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", - "AWS_DEFAULT_REGION": "us-west-1", + "AWS_REGION": "us-west-1", }) - oldNow := now - defer func() { now = oldNow }() now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -774,21 +823,30 @@ func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) { func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldNow := now + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{ "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", "AWS_REGION": "us-west-1", "AWS_DEFAULT_REGION": "us-east-1", }) - oldNow := now - defer func() { now = oldNow }() now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -815,16 +873,25 @@ func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) { func TestAWSCredential_RequestWithBadVersion(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource.EnvironmentID = "aws3" oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{}) + validHostnames = []string{tsURL.Hostname()} - _, err := tfc.parse(context.Background()) + _, err = tfc.parse(context.Background()) if err == nil { t.Fatalf("parse() should have failed") } @@ -836,14 +903,23 @@ func TestAWSCredential_RequestWithBadVersion(t *testing.T) { func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource.RegionURL = "" oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{}) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -863,14 +939,23 @@ func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) { func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } server.WriteRegion = notFound tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{}) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -890,6 +975,10 @@ func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) { func TestAWSCredential_RequestWithMissingCredential(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("{}")) } @@ -898,8 +987,13 @@ func TestAWSCredential_RequestWithMissingCredential(t *testing.T) { tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{}) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -919,6 +1013,10 @@ func TestAWSCredential_RequestWithMissingCredential(t *testing.T) { func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) { w.Write([]byte(`{"AccessKeyId":"FOOBARBAS"}`)) } @@ -927,8 +1025,13 @@ func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) { tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{}) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -948,14 +1051,23 @@ func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) { func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource.URL = "" oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{}) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -975,14 +1087,23 @@ func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) { func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } server.WriteRolename = notFound tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{}) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -1002,14 +1123,23 @@ func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) { func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } server.WriteSecurityCredentials = notFound tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{}) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -1025,3 +1155,290 @@ func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) { t.Errorf("subjectToken = %q, want %q", got, want) } } + +func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing.T) { + server := createDefaultAwsTestServer() + ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } + + metadataTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("Metadata server should not have been called.") + })) + + tfc := testFileConfig + tfc.CredentialSource = server.getCredentialSource(ts.URL) + tfc.CredentialSource.IMDSv2SessionTokenURL = metadataTs.URL + + oldGetenv := getenv + oldNow := now + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() + getenv = setEnvironment(map[string]string{ + "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", + "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + "AWS_REGION": "us-west-1", + }) + now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} + + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + out, err := base.subjectToken() + if err != nil { + t.Fatalf("retrieveSubjectToken() failed: %v", err) + } + + expected := getExpectedSubjectToken( + "https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "us-west-1", + "AKIDEXAMPLE", + "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + "", + ) + + if got, want := out, expected; !reflect.DeepEqual(got, want) { + t.Errorf("subjectToken = \n%q\n want \n%q", got, want) + } +} + +func TestAWSCredential_ShouldCallMetadataEndpointWhenNoRegion(t *testing.T) { + server := createDefaultAwsTestServerWithImdsv2(t) + ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } + + tfc := testFileConfig + tfc.CredentialSource = server.getCredentialSource(ts.URL) + + oldGetenv := getenv + oldNow := now + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() + getenv = setEnvironment(map[string]string{ + "AWS_ACCESS_KEY_ID": accessKeyID, + "AWS_SECRET_ACCESS_KEY": secretAccessKey, + }) + now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} + + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + out, err := base.subjectToken() + if err != nil { + t.Fatalf("retrieveSubjectToken() failed: %v", err) + } + + expected := getExpectedSubjectToken( + "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "us-east-2", + accessKeyID, + secretAccessKey, + "", + ) + + if got, want := out, expected; !reflect.DeepEqual(got, want) { + t.Errorf("subjectToken = \n%q\n want \n%q", got, want) + } +} + +func TestAWSCredential_ShouldCallMetadataEndpointWhenNoAccessKey(t *testing.T) { + server := createDefaultAwsTestServerWithImdsv2(t) + ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } + + tfc := testFileConfig + tfc.CredentialSource = server.getCredentialSource(ts.URL) + + oldGetenv := getenv + oldNow := now + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() + getenv = setEnvironment(map[string]string{ + "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + "AWS_REGION": "us-west-1", + }) + now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} + + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + out, err := base.subjectToken() + if err != nil { + t.Fatalf("retrieveSubjectToken() failed: %v", err) + } + + expected := getExpectedSubjectToken( + "https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "us-west-1", + accessKeyID, + secretAccessKey, + securityToken, + ) + + if got, want := out, expected; !reflect.DeepEqual(got, want) { + t.Errorf("subjectToken = \n%q\n want \n%q", got, want) + } +} + +func TestAWSCredential_ShouldCallMetadataEndpointWhenNoSecretAccessKey(t *testing.T) { + server := createDefaultAwsTestServerWithImdsv2(t) + ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } + + tfc := testFileConfig + tfc.CredentialSource = server.getCredentialSource(ts.URL) + + oldGetenv := getenv + oldNow := now + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() + getenv = setEnvironment(map[string]string{ + "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", + "AWS_REGION": "us-west-1", + }) + now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} + + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + out, err := base.subjectToken() + if err != nil { + t.Fatalf("retrieveSubjectToken() failed: %v", err) + } + + expected := getExpectedSubjectToken( + "https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "us-west-1", + accessKeyID, + secretAccessKey, + securityToken, + ) + + if got, want := out, expected; !reflect.DeepEqual(got, want) { + t.Errorf("subjectToken = \n%q\n want \n%q", got, want) + } +} + +func TestAWSCredential_Validations(t *testing.T) { + var metadataServerValidityTests = []struct { + name string + credSource CredentialSource + errText string + }{ + { + name: "No Metadata Server URLs", + credSource: CredentialSource{ + EnvironmentID: "aws1", + RegionURL: "", + URL: "", + IMDSv2SessionTokenURL: "", + }, + }, { + name: "IPv4 Metadata Server URLs", + credSource: CredentialSource{ + EnvironmentID: "aws1", + RegionURL: "http://169.254.169.254/latest/meta-data/placement/availability-zone", + URL: "http://169.254.169.254/latest/meta-data/iam/security-credentials", + IMDSv2SessionTokenURL: "http://169.254.169.254/latest/api/token", + }, + }, { + name: "IPv6 Metadata Server URLs", + credSource: CredentialSource{ + EnvironmentID: "aws1", + RegionURL: "http://[fd00:ec2::254]/latest/meta-data/placement/availability-zone", + URL: "http://[fd00:ec2::254]/latest/meta-data/iam/security-credentials", + IMDSv2SessionTokenURL: "http://[fd00:ec2::254]/latest/api/token", + }, + }, { + name: "Faulty RegionURL", + credSource: CredentialSource{ + EnvironmentID: "aws1", + RegionURL: "http://abc.com/latest/meta-data/placement/availability-zone", + URL: "http://169.254.169.254/latest/meta-data/iam/security-credentials", + IMDSv2SessionTokenURL: "http://169.254.169.254/latest/api/token", + }, + errText: "oauth2/google: invalid hostname http://abc.com/latest/meta-data/placement/availability-zone for region_url", + }, { + name: "Faulty CredVerificationURL", + credSource: CredentialSource{ + EnvironmentID: "aws1", + RegionURL: "http://169.254.169.254/latest/meta-data/placement/availability-zone", + URL: "http://abc.com/latest/meta-data/iam/security-credentials", + IMDSv2SessionTokenURL: "http://169.254.169.254/latest/api/token", + }, + errText: "oauth2/google: invalid hostname http://abc.com/latest/meta-data/iam/security-credentials for url", + }, { + name: "Faulty IMDSv2SessionTokenURL", + credSource: CredentialSource{ + EnvironmentID: "aws1", + RegionURL: "http://169.254.169.254/latest/meta-data/placement/availability-zone", + URL: "http://169.254.169.254/latest/meta-data/iam/security-credentials", + IMDSv2SessionTokenURL: "http://abc.com/latest/api/token", + }, + errText: "oauth2/google: invalid hostname http://abc.com/latest/api/token for imdsv2_session_token_url", + }, + } + + for _, tt := range metadataServerValidityTests { + t.Run(tt.name, func(t *testing.T) { + tfc := testFileConfig + tfc.CredentialSource = tt.credSource + + oldGetenv := getenv + defer func() { getenv = oldGetenv }() + getenv = setEnvironment(map[string]string{}) + + _, err := tfc.parse(context.Background()) + if err != nil { + if tt.errText == "" { + t.Errorf("Didn't expect an error, but got %v", err) + } else if tt.errText != err.Error() { + t.Errorf("Expected %v, but got %v", tt.errText, err) + } + } else { + if tt.errText != "" { + t.Errorf("Expected error %v, but got none", tt.errText) + } + } + }) + } +} diff --git a/google/internal/externalaccount/basecredentials.go b/google/internal/externalaccount/basecredentials.go index 9fc35535e..dcd252a61 100644 --- a/google/internal/externalaccount/basecredentials.go +++ b/google/internal/externalaccount/basecredentials.go @@ -67,22 +67,6 @@ type Config struct { // that include all elements in a given list, in that order. var ( - validTokenURLPatterns = []*regexp.Regexp{ - // The complicated part in the middle matches any number of characters that - // aren't period, spaces, or slashes. - regexp.MustCompile(`(?i)^[^\.\s\/\\]+\.sts\.googleapis\.com$`), - regexp.MustCompile(`(?i)^sts\.googleapis\.com$`), - regexp.MustCompile(`(?i)^sts\.[^\.\s\/\\]+\.googleapis\.com$`), - regexp.MustCompile(`(?i)^[^\.\s\/\\]+-sts\.googleapis\.com$`), - regexp.MustCompile(`(?i)^sts-[^\.\s\/\\]+\.p\.googleapis\.com$`), - } - validImpersonateURLPatterns = []*regexp.Regexp{ - regexp.MustCompile(`^[^\.\s\/\\]+\.iamcredentials\.googleapis\.com$`), - regexp.MustCompile(`^iamcredentials\.googleapis\.com$`), - regexp.MustCompile(`^iamcredentials\.[^\.\s\/\\]+\.googleapis\.com$`), - regexp.MustCompile(`^[^\.\s\/\\]+-iamcredentials\.googleapis\.com$`), - regexp.MustCompile(`^iamcredentials-[^\.\s\/\\]+\.p\.googleapis\.com$`), - } validWorkforceAudiencePattern *regexp.Regexp = regexp.MustCompile(`//iam\.googleapis\.com/locations/[^/]+/workforcePools/`) ) @@ -110,25 +94,13 @@ func validateWorkforceAudience(input string) bool { // TokenSource Returns an external account TokenSource struct. This is to be called by package google to construct a google.Credentials. func (c *Config) TokenSource(ctx context.Context) (oauth2.TokenSource, error) { - return c.tokenSource(ctx, validTokenURLPatterns, validImpersonateURLPatterns, "https") + return c.tokenSource(ctx, "https") } // tokenSource is a private function that's directly called by some of the tests, // because the unit test URLs are mocked, and would otherwise fail the // validity check. -func (c *Config) tokenSource(ctx context.Context, tokenURLValidPats []*regexp.Regexp, impersonateURLValidPats []*regexp.Regexp, scheme string) (oauth2.TokenSource, error) { - valid := validateURL(c.TokenURL, tokenURLValidPats, scheme) - if !valid { - return nil, fmt.Errorf("oauth2/google: invalid TokenURL provided while constructing tokenSource") - } - - if c.ServiceAccountImpersonationURL != "" { - valid := validateURL(c.ServiceAccountImpersonationURL, impersonateURLValidPats, scheme) - if !valid { - return nil, fmt.Errorf("oauth2/google: invalid ServiceAccountImpersonationURL provided while constructing tokenSource") - } - } - +func (c *Config) tokenSource(ctx context.Context, scheme string) (oauth2.TokenSource, error) { if c.WorkforcePoolUserProject != "" { valid := validateWorkforceAudience(c.Audience) if !valid { @@ -213,6 +185,10 @@ func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) { awsCredSource.IMDSv2SessionTokenURL = c.CredentialSource.IMDSv2SessionTokenURL } + if err := awsCredSource.validateMetadataServers(); err != nil { + return nil, err + } + return awsCredSource, nil } } else if c.CredentialSource.File != "" { diff --git a/google/internal/externalaccount/basecredentials_test.go b/google/internal/externalaccount/basecredentials_test.go index 05e0127f0..bf6be321c 100644 --- a/google/internal/externalaccount/basecredentials_test.go +++ b/google/internal/externalaccount/basecredentials_test.go @@ -9,7 +9,6 @@ import ( "io/ioutil" "net/http" "net/http/httptest" - "strings" "testing" "time" @@ -208,140 +207,6 @@ func TestNonworkforceWithWorkforcePoolUserProject(t *testing.T) { } } -func TestValidateURLTokenURL(t *testing.T) { - var urlValidityTests = []struct { - tokURL string - expectSuccess bool - }{ - {"https://east.sts.googleapis.com", true}, - {"https://sts.googleapis.com", true}, - {"https://sts.asfeasfesef.googleapis.com", true}, - {"https://us-east-1-sts.googleapis.com", true}, - {"https://sts.googleapis.com/your/path/here", true}, - {"https://.sts.googleapis.com", false}, - {"https://badsts.googleapis.com", false}, - {"https://sts.asfe.asfesef.googleapis.com", false}, - {"https://sts..googleapis.com", false}, - {"https://-sts.googleapis.com", false}, - {"https://us-ea.st-1-sts.googleapis.com", false}, - {"https://sts.googleapis.com.evil.com/whatever/path", false}, - {"https://us-eas\\t-1.sts.googleapis.com", false}, - {"https:/us-ea/st-1.sts.googleapis.com", false}, - {"https:/us-east 1.sts.googleapis.com", false}, - {"https://", false}, - {"http://us-east-1.sts.googleapis.com", false}, - {"https://us-east-1.sts.googleapis.comevil.com", false}, - {"https://sts-xyz.p.googleapis.com", true}, - {"https://sts.pgoogleapis.com", false}, - {"https://p.googleapis.com", false}, - {"https://sts.p.com", false}, - {"http://sts.p.googleapis.com", false}, - {"https://xyz-sts.p.googleapis.com", false}, - {"https://sts-xyz.123.p.googleapis.com", false}, - {"https://sts-xyz.p1.googleapis.com", false}, - {"https://sts-xyz.p.foo.com", false}, - {"https://sts-xyz.p.foo.googleapis.com", false}, - } - ctx := context.Background() - for _, tt := range urlValidityTests { - t.Run(" "+tt.tokURL, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability. - config := testConfig - config.TokenURL = tt.tokURL - _, err := config.TokenSource(ctx) - - if tt.expectSuccess && err != nil { - t.Errorf("got %v but want nil", err) - } else if !tt.expectSuccess && err == nil { - t.Errorf("got nil but expected an error") - } - }) - } - for _, el := range urlValidityTests { - el.tokURL = strings.ToUpper(el.tokURL) - } - for _, tt := range urlValidityTests { - t.Run(" "+tt.tokURL, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability. - config := testConfig - config.TokenURL = tt.tokURL - _, err := config.TokenSource(ctx) - - if tt.expectSuccess && err != nil { - t.Errorf("got %v but want nil", err) - } else if !tt.expectSuccess && err == nil { - t.Errorf("got nil but expected an error") - } - }) - } -} - -func TestValidateURLImpersonateURL(t *testing.T) { - var urlValidityTests = []struct { - impURL string - expectSuccess bool - }{ - {"https://east.iamcredentials.googleapis.com", true}, - {"https://iamcredentials.googleapis.com", true}, - {"https://iamcredentials.asfeasfesef.googleapis.com", true}, - {"https://us-east-1-iamcredentials.googleapis.com", true}, - {"https://iamcredentials.googleapis.com/your/path/here", true}, - {"https://.iamcredentials.googleapis.com", false}, - {"https://badiamcredentials.googleapis.com", false}, - {"https://iamcredentials.asfe.asfesef.googleapis.com", false}, - {"https://iamcredentials..googleapis.com", false}, - {"https://-iamcredentials.googleapis.com", false}, - {"https://us-ea.st-1-iamcredentials.googleapis.com", false}, - {"https://iamcredentials.googleapis.com.evil.com/whatever/path", false}, - {"https://us-eas\\t-1.iamcredentials.googleapis.com", false}, - {"https:/us-ea/st-1.iamcredentials.googleapis.com", false}, - {"https:/us-east 1.iamcredentials.googleapis.com", false}, - {"https://", false}, - {"http://us-east-1.iamcredentials.googleapis.com", false}, - {"https://us-east-1.iamcredentials.googleapis.comevil.com", false}, - {"https://iamcredentials-xyz.p.googleapis.com", true}, - {"https://iamcredentials.pgoogleapis.com", false}, - {"https://p.googleapis.com", false}, - {"https://iamcredentials.p.com", false}, - {"http://iamcredentials.p.googleapis.com", false}, - {"https://xyz-iamcredentials.p.googleapis.com", false}, - {"https://iamcredentials-xyz.123.p.googleapis.com", false}, - {"https://iamcredentials-xyz.p1.googleapis.com", false}, - {"https://iamcredentials-xyz.p.foo.com", false}, - {"https://iamcredentials-xyz.p.foo.googleapis.com", false}, - } - ctx := context.Background() - for _, tt := range urlValidityTests { - t.Run(" "+tt.impURL, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability. - config := testConfig - config.TokenURL = "https://sts.googleapis.com" // Setting the most basic acceptable tokenURL - config.ServiceAccountImpersonationURL = tt.impURL - _, err := config.TokenSource(ctx) - - if tt.expectSuccess && err != nil { - t.Errorf("got %v but want nil", err) - } else if !tt.expectSuccess && err == nil { - t.Errorf("got nil but expected an error") - } - }) - } - for _, el := range urlValidityTests { - el.impURL = strings.ToUpper(el.impURL) - } - for _, tt := range urlValidityTests { - t.Run(" "+tt.impURL, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability. - config := testConfig - config.TokenURL = "https://sts.googleapis.com" // Setting the most basic acceptable tokenURL - config.ServiceAccountImpersonationURL = tt.impURL - _, err := config.TokenSource(ctx) - - if tt.expectSuccess && err != nil { - t.Errorf("got %v but want nil", err) - } else if !tt.expectSuccess && err == nil { - t.Errorf("got nil but expected an error") - } - }) - } -} - func TestWorkforcePoolCreation(t *testing.T) { var audienceValidatyTests = []struct { audience string diff --git a/google/internal/externalaccount/impersonate_test.go b/google/internal/externalaccount/impersonate_test.go index 17e2f6d72..8c7f6a9a7 100644 --- a/google/internal/externalaccount/impersonate_test.go +++ b/google/internal/externalaccount/impersonate_test.go @@ -9,7 +9,6 @@ import ( "io/ioutil" "net/http" "net/http/httptest" - "regexp" "testing" ) @@ -114,8 +113,7 @@ func TestImpersonation(t *testing.T) { defer targetServer.Close() testImpersonateConfig.TokenURL = targetServer.URL - allURLs := regexp.MustCompile(".+") - ourTS, err := testImpersonateConfig.tokenSource(context.Background(), []*regexp.Regexp{allURLs}, []*regexp.Regexp{allURLs}, "http") + ourTS, err := testImpersonateConfig.tokenSource(context.Background(), "http") if err != nil { t.Fatalf("Failed to create TokenSource: %v", err) } diff --git a/oauth2.go b/oauth2.go index 5cb7ca175..0fa61292d 100644 --- a/oauth2.go +++ b/oauth2.go @@ -16,6 +16,7 @@ import ( "net/url" "strings" "sync" + "time" "golang.org/x/oauth2/advancedauth" "golang.org/x/oauth2/internal" @@ -316,6 +317,8 @@ type reuseTokenSource struct { mu sync.Mutex // guards t t *Token + + expiryDelta time.Duration } // Token returns the current token if it's still valid, else will @@ -331,6 +334,7 @@ func (s *reuseTokenSource) Token() (*Token, error) { if err != nil { return nil, err } + t.expiryDelta = s.expiryDelta s.t = t return t, nil } @@ -405,3 +409,30 @@ func ReuseTokenSource(t *Token, src TokenSource) TokenSource { new: src, } } + +// ReuseTokenSource returns a TokenSource that acts in the same manner as the +// TokenSource returned by ReuseTokenSource, except the expiry buffer is +// configurable. The expiration time of a token is calculated as +// t.Expiry.Add(-earlyExpiry). +func ReuseTokenSourceWithExpiry(t *Token, src TokenSource, earlyExpiry time.Duration) TokenSource { + // Don't wrap a reuseTokenSource in itself. That would work, + // but cause an unnecessary number of mutex operations. + // Just build the equivalent one. + if rt, ok := src.(*reuseTokenSource); ok { + if t == nil { + // Just use it directly, but set the expiryDelta to earlyExpiry, + // so the behavior matches what the user expects. + rt.expiryDelta = earlyExpiry + return rt + } + src = rt.new + } + if t != nil { + t.expiryDelta = earlyExpiry + } + return &reuseTokenSource{ + t: t, + new: src, + expiryDelta: earlyExpiry, + } +} diff --git a/token.go b/token.go index 822720341..55a1fedca 100644 --- a/token.go +++ b/token.go @@ -21,6 +21,11 @@ import ( // expirations due to client-server time mismatches. const expiryDelta = 10 * time.Second +// defaultExpiryDelta determines how earlier a token should be considered +// expired than its actual expiration time. It is used to avoid late +// expirations due to client-server time mismatches. +const defaultExpiryDelta = 10 * time.Second + // Token represents the credentials used to authorize // the requests to access protected resources on the OAuth 2.0 // provider's backend. @@ -52,6 +57,11 @@ type Token struct { // raw optionally contains extra metadata from the server // when updating a token. raw interface{} + + // expiryDelta is used to calculate when a token is considered + // expired, by subtracting from Expiry. If zero, defaultExpiryDelta + // is used. + expiryDelta time.Duration } // Type returns t.TokenType if non-empty, else "Bearer".