From 13117402d895f69ab171baadcd1cbfed778b0282 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 23 Dec 2022 10:26:49 +0700 Subject: [PATCH] Attempt to login and retry request when access token was expired --- pkg/client/client.go | 41 +++++++++++++++++++++++----- pkg/client/client_test.go | 56 ++++++++++++++++++++++++++++++++++++++- pkg/client/helper.go | 51 +++++++++++++++++++++++++---------- pkg/client/helper_test.go | 6 +++-- 4 files changed, 130 insertions(+), 24 deletions(-) diff --git a/pkg/client/client.go b/pkg/client/client.go index 4a19837..706798a 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -17,7 +17,7 @@ type Config struct { type Client struct { user *loginParams baseURL string - accessToken string + accessToken *cString } const ( @@ -39,7 +39,8 @@ func NewClient(cfg *Config) (*Client, error) { Username: cfg.Username, Password: cfg.Password, }, - baseURL: fmt.Sprintf("%s/%s/v1/", cfg.Address, contextPath), + baseURL: fmt.Sprintf("%s/%s/v1/", cfg.Address, contextPath), + accessToken: &cString{}, } if err := client.login(); err != nil { @@ -52,7 +53,7 @@ func NewClient(cfg *Config) (*Client, error) { func (c *Client) login() error { var resp loginResponse - err := request( + err := c.request( context.Background(), http.MethodPost, c.baseURL+LoginPath, &resp, withForm( "username", c.user.Username, @@ -61,13 +62,39 @@ func (c *Client) login() error { return err } - c.accessToken = resp.AccessToken + c.accessToken.set(resp.AccessToken) return nil } +func (c *Client) request(ctx context.Context, method, url string, result interface{}, opts ...requestOptionFn) error { + _request := func() error { + req, err := newRequest(ctx, method, url, opts...) + if err != nil { + return fmt.Errorf("failed to create new request: %v", err) + } + + err = sendRequest(req, result) + if err != nil { + return fmt.Errorf("failed to send request = %v: %v", *req, err) + } + + return err + } + + err := _request() + if isTokenExpiredError(err) { + if loginErr := c.login(); loginErr != nil { + return fmt.Errorf("token expired %s, re-login attempt failed: err = %w ", err, loginErr) + } + err = _request() + } + + return err +} + func (c *Client) GetConfiguration(ctx context.Context, params *ConfigurationId) (*Configuration, error) { var resp Configuration - err := request( + err := c.request( ctx, http.MethodGet, c.baseURL+ConfigurationPath, &resp, withAuthentication(c.accessToken), withQuery( @@ -88,7 +115,7 @@ func (c *Client) GetConfiguration(ctx context.Context, params *ConfigurationId) func (c *Client) PublishConfiguration(ctx context.Context, params *Configuration) error { var resp bool - err := request( + err := c.request( ctx, http.MethodPost, c.baseURL+ConfigurationPath, &resp, withAuthentication(c.accessToken), withForm( @@ -106,7 +133,7 @@ func (c *Client) PublishConfiguration(ctx context.Context, params *Configuration func (c *Client) DeleteConfiguration(ctx context.Context, params *ConfigurationId) (bool, error) { var resp bool - err := request( + err := c.request( ctx, http.MethodDelete, c.baseURL+ConfigurationPath, &resp, withAuthentication(c.accessToken), withQuery( diff --git a/pkg/client/client_test.go b/pkg/client/client_test.go index 4e6a586..d9124e8 100644 --- a/pkg/client/client_test.go +++ b/pkg/client/client_test.go @@ -27,6 +27,22 @@ func defaultLoginHandler(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write(jsonResp) } +func injectResponseToHandler(statusCode int, resp interface{}, handler http.HandlerFunc) http.HandlerFunc { + injectResponse := true + return func(w http.ResponseWriter, r *http.Request) { + if !injectResponse { + handler(w, r) + return + } + + jsonResp, _ := json.Marshal(resp) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + _, _ = w.Write(jsonResp) + injectResponse = false + } +} + func TestNewClient(t *testing.T) { tests := []struct { name string @@ -70,7 +86,7 @@ func TestNewClient(t *testing.T) { }) if tt.expectErr == nil { assert.Nil(t, err) - assert.Equal(t, _AccessToken, client.accessToken) + assert.Equal(t, _AccessToken, client.accessToken.value()) } else { assert.NotNil(t, err) } @@ -114,6 +130,44 @@ func TestClient_GetConfiguration(t *testing.T) { }, expectErr: nil, }, + { + name: "failed and no retry", + getConfigHandler: injectResponseToHandler( + 500, map[string]interface{}{ + "status": "500", + "error": "Internal", + "message": "server internal error!", + }, func(w http.ResponseWriter, _ *http.Request) { + jsonResp, _ := json.Marshal(map[string]interface{}{ + "tenant": "namespace", + "group": "GROUP", + "dataId": "key", + "content": "value", + }) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write(jsonResp) + }), + expectErr: nil, + }, + { + name: "success after re-login", + getConfigHandler: injectResponseToHandler( + 403, map[string]interface{}{ + "status": "403", + "error": "Forbidden", + "message": "token expired!", + }, func(w http.ResponseWriter, _ *http.Request) { + jsonResp, _ := json.Marshal(map[string]interface{}{ + "tenant": "namespace", + "group": "GROUP", + "dataId": "key", + "content": "value", + }) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write(jsonResp) + }), + expectErr: nil, + }, } for _, tt := range tests { diff --git a/pkg/client/helper.go b/pkg/client/helper.go index 69eaaf9..bfea48b 100644 --- a/pkg/client/helper.go +++ b/pkg/client/helper.go @@ -7,7 +7,9 @@ import ( "io" "net/http" "net/url" + "regexp" "strings" + "sync" ) type requestOption struct { @@ -43,8 +45,14 @@ func withQuery(kv ...string) requestOptionFn { } } -func withAuthentication(token string) requestOptionFn { - return withQuery(accessTokenQueryName, token) +func withAuthentication(token *cString) requestOptionFn { + return func(rOpts *requestOption) error { + if rOpts.query == nil { + rOpts.query = &url.Values{} + } + + return updateValues("query string", rOpts.query, accessTokenQueryName, token.value()) + } } func withForm(kv ...string) requestOptionFn { @@ -91,6 +99,19 @@ func newRequest(ctx context.Context, method, url string, opts ...requestOptionFn return req, nil } +func isTokenExpiredError(err error) bool { + if err == nil { + return false + } + re := regexp.MustCompile(`request error status_code = (\d*), body = (.*)`) + matches := re.FindAllStringSubmatch(err.Error(), -1) + if matches == nil || len(matches[0]) != 3 { + return false + } + statusCode, body := matches[0][1], matches[0][2] + return statusCode == "403" && strings.Contains(body, `"message":"token expired!"`) +} + func sendRequest(req *http.Request, result interface{}) error { var err error resp, err := http.DefaultClient.Do(req) @@ -119,18 +140,20 @@ func sendRequest(req *http.Request, result interface{}) error { return nil } -func request(ctx context.Context, method, url string, result interface{}, opts ...requestOptionFn) error { - var err error - - req, err := newRequest(ctx, method, url, opts...) - if err != nil { - return fmt.Errorf("failed to create new request: %v", err) - } +// cString is a concurrent safe string +type cString struct { + mux sync.RWMutex + v string +} - err = sendRequest(req, result) - if err != nil { - return fmt.Errorf("failed to send request = %v: %v", *req, err) - } +func (ts *cString) value() string { + ts.mux.RLock() + defer ts.mux.RUnlock() + return ts.v +} - return nil +func (ts *cString) set(s string) { + ts.mux.Lock() + defer ts.mux.Unlock() + ts.v = s } diff --git a/pkg/client/helper_test.go b/pkg/client/helper_test.go index a59d94a..931c01d 100644 --- a/pkg/client/helper_test.go +++ b/pkg/client/helper_test.go @@ -3,9 +3,10 @@ package client import ( "context" "fmt" - "github.com/stretchr/testify/assert" "strings" "testing" + + "github.com/stretchr/testify/assert" ) func TestNewRequest(t *testing.T) { @@ -29,7 +30,8 @@ func TestNewRequest(t *testing.T) { for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - err := request( + cli := new(Client) + err := cli.request( context.Background(), "GET", "/test", nil, withForm(tc.form...), withQuery(tc.query...)) assert.True(t, strings.Contains(err.Error(), tc.err.Error()))