From 9042551d75f6094587a91020fe5277d01997ce4d Mon Sep 17 00:00:00 2001 From: Shawn Neal Date: Fri, 27 Oct 2023 09:11:33 -0700 Subject: [PATCH] Add request context usage for oauth token renewals - Fix issue 325 --- client/client.go | 6 ++--- internal/http/client_provider.go | 13 ++++++----- internal/http/executor.go | 9 ++++---- internal/http/oauth_session_manager.go | 12 +++++----- internal/http/oauth_session_manager_test.go | 25 +++++++++++---------- 5 files changed, 35 insertions(+), 30 deletions(-) diff --git a/client/client.go b/client/client.go index 24fd60e2..e4b1f0d4 100644 --- a/client/client.go +++ b/client/client.go @@ -144,8 +144,8 @@ func New(config *config.Config) (*Client, error) { } // AccessToken returns the raw encoded OAuth access token without the bearer prefix -func (c *Client) AccessToken(ignoredCtx context.Context) (string, error) { - token, err := c.authenticatedClientProvider.AccessToken() +func (c *Client) AccessToken(ctx context.Context) (string, error) { + token, err := c.authenticatedClientProvider.AccessToken(ctx) if err != nil { return "", err } @@ -164,7 +164,7 @@ func (c *Client) SSHCode(ctx context.Context) (string, error) { values.Set("response_type", "code") values.Set("client_id", r.Links.AppSSH.Meta.OauthClient) // client_id,used by cf server - token, err := c.authenticatedClientProvider.AccessToken() + token, err := c.authenticatedClientProvider.AccessToken(ctx) if err != nil { return "", err } diff --git a/internal/http/client_provider.go b/internal/http/client_provider.go index d73c7cc5..b6409876 100644 --- a/internal/http/client_provider.go +++ b/internal/http/client_provider.go @@ -1,13 +1,16 @@ package http -import "net/http" +import ( + "context" + "net/http" +) type ClientProvider interface { // Client returns a *http.Client - Client(followRedirects bool) (*http.Client, error) + Client(ctx context.Context, followRedirects bool) (*http.Client, error) // ReAuthenticate tells the provider to re-initialize the auth context - ReAuthenticate() error + ReAuthenticate(ctx context.Context) error } type UnauthenticatedClientProvider struct { @@ -15,14 +18,14 @@ type UnauthenticatedClientProvider struct { httpClientNonRedirecting *http.Client } -func (c *UnauthenticatedClientProvider) Client(followRedirects bool) (*http.Client, error) { +func (c *UnauthenticatedClientProvider) Client(ctx context.Context, followRedirects bool) (*http.Client, error) { if followRedirects { return c.httpClient, nil } return c.httpClientNonRedirecting, nil } -func (c *UnauthenticatedClientProvider) ReAuthenticate() error { +func (c *UnauthenticatedClientProvider) ReAuthenticate(ctx context.Context) error { return nil } diff --git a/internal/http/executor.go b/internal/http/executor.go index 9e51c7bc..65c62ee8 100644 --- a/internal/http/executor.go +++ b/internal/http/executor.go @@ -2,6 +2,7 @@ package http import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -61,7 +62,7 @@ func (c *Executor) ExecuteRequest(request *Request) (*http.Response, error) { // refresh token is expired or revoked. Attempt to get a new refresh and access token and retry the request. var authErr *unauthorizedError if errors.As(err, &authErr) { - err = c.reAuthenticate() + err = c.reAuthenticate(req.Context()) if err != nil { return nil, err } @@ -111,7 +112,7 @@ func (c *Executor) newHTTPRequest(request *Request) (*http.Request, error) { // do will get the proper http.Client and calls Do on it using the specified http.Request func (c *Executor) do(request *http.Request, followRedirects bool) (*http.Response, error) { - client, err := c.clientProvider.Client(followRedirects) + client, err := c.clientProvider.Client(request.Context(), followRedirects) if err != nil { return nil, fmt.Errorf("error executing request, failed to get the underlying HTTP client: %w", err) } @@ -148,8 +149,8 @@ func (c *Executor) do(request *http.Request, followRedirects bool) (*http.Respon } // reAuthenticate tells the client provider to restart authentication anew because we received a 401 -func (c *Executor) reAuthenticate() error { - err := c.clientProvider.ReAuthenticate() +func (c *Executor) reAuthenticate(ctx context.Context) error { + err := c.clientProvider.ReAuthenticate(ctx) if err != nil { return fmt.Errorf("an error occurred attempting to reauthenticate "+ "after initially receiving a 401 executing a request: %w", err) diff --git a/internal/http/oauth_session_manager.go b/internal/http/oauth_session_manager.go index 7d0822f5..3716cb24 100644 --- a/internal/http/oauth_session_manager.go +++ b/internal/http/oauth_session_manager.go @@ -36,8 +36,8 @@ func NewOAuthSessionManager(config *config.Config) *OAuthSessionManager { } // Client returns an authenticated OAuth http client -func (m *OAuthSessionManager) Client(followRedirects bool) (*http.Client, error) { - err := m.init(context.Background()) +func (m *OAuthSessionManager) Client(ctx context.Context, followRedirects bool) (*http.Client, error) { + err := m.init(ctx) if err != nil { return nil, err } @@ -55,7 +55,7 @@ func (m *OAuthSessionManager) Client(followRedirects bool) (*http.Client, error) // likely in response to a 401 // // This won't work for userTokenAuth since we have no credentials to exchange for a new token. -func (m *OAuthSessionManager) ReAuthenticate() error { +func (m *OAuthSessionManager) ReAuthenticate(ctx context.Context) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -64,12 +64,12 @@ func (m *OAuthSessionManager) ReAuthenticate() error { } // attempt to create a new token source - return m.newTokenSource(context.Background()) + return m.newTokenSource(ctx) } // AccessToken returns the raw OAuth access token -func (m *OAuthSessionManager) AccessToken() (string, error) { - err := m.init(context.Background()) +func (m *OAuthSessionManager) AccessToken(ctx context.Context) (string, error) { + err := m.init(ctx) if err != nil { return "", err } diff --git a/internal/http/oauth_session_manager_test.go b/internal/http/oauth_session_manager_test.go index d7e29cd7..f257541a 100644 --- a/internal/http/oauth_session_manager_test.go +++ b/internal/http/oauth_session_manager_test.go @@ -1,6 +1,7 @@ package http_test import ( + "context" "github.com/cloudfoundry-community/go-cfclient/v3/config" "github.com/cloudfoundry-community/go-cfclient/v3/internal/http" "github.com/cloudfoundry-community/go-cfclient/v3/testutil" @@ -22,7 +23,7 @@ func TestOAuthSessionManager(t *testing.T) { require.Empty(t, c.UAAEndpointURL) m := http.NewOAuthSessionManager(c) - _, err = m.Client(true) + _, err = m.Client(context.Background(), true) require.Error(t, err, "expected an error when UAA or Login endpoint is empty") require.Equal(t, "login and UAA endpoints must not be empty", err.Error()) @@ -31,35 +32,35 @@ func TestOAuthSessionManager(t *testing.T) { c.UAAEndpointURL = uaaURL // we can create a client that utilizes oauth - client1, err := m.Client(true) + client1, err := m.Client(context.Background(), true) require.NoError(t, err) require.NotNil(t, client1) // the same access token is returned as long as it's not expired (which it's not - 300s) - token, err := m.AccessToken() + token, err := m.AccessToken(context.Background()) require.NoError(t, err) require.Equal(t, "foobar1", token) require.NoError(t, err) - token, err = m.AccessToken() + token, err = m.AccessToken(context.Background()) require.NoError(t, err) require.Equal(t, "foobar1", token) // the same client is returned - client2, err := m.Client(true) + client2, err := m.Client(context.Background(), true) require.NoError(t, err) require.Same(t, client1, client2) // we force new auth context - err = m.ReAuthenticate() + err = m.ReAuthenticate(context.Background()) require.NoError(t, err) // a different client is now returned - client3, err := m.Client(true) + client3, err := m.Client(context.Background(), true) require.NoError(t, err) require.NotSame(t, client2, client3) // a new token is also returned - token, err = m.AccessToken() + token, err = m.AccessToken(context.Background()) require.NoError(t, err) require.Equal(t, "foobar2", token) @@ -84,22 +85,22 @@ func TestOAuthSessionManagerRefreshToken(t *testing.T) { m := http.NewOAuthSessionManager(c) // we can create a client that utilizes oauth - client1, err := m.Client(true) + client1, err := m.Client(context.Background(), true) require.NoError(t, err) require.NotNil(t, client1) // get the access token, it should have been auto-refreshed because the one we gave in the config was expired - token, err := m.AccessToken() + token, err := m.AccessToken(context.Background()) require.NoError(t, err) require.NotEqual(t, accessToken, token) require.Equal(t, "foobar1", token) // get the access token, should be the same as before - token, err = m.AccessToken() + token, err = m.AccessToken(context.Background()) require.NoError(t, err) require.Equal(t, "foobar1", token) // we cannot re-auth with only a refresh token (no credentials) - err = m.ReAuthenticate() + err = m.ReAuthenticate(context.Background()) require.EqualError(t, err, "cannot reauthenticate user token auth type, check your access and/or refresh token expiration date") }