Skip to content

Commit

Permalink
Merge pull request #3 from zalopay-oss/bug-fix/token-expired
Browse files Browse the repository at this point in the history
Attempt to login and retry request when access token was expired
  • Loading branch information
tungluu18 authored Dec 23, 2022
2 parents db8d324 + 1311740 commit 6d0de25
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 24 deletions.
41 changes: 34 additions & 7 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type Config struct {
type Client struct {
user *loginParams
baseURL string
accessToken string
accessToken *cString
}

const (
Expand All @@ -39,7 +39,8 @@ func NewClient(cfg *Config) (*Client, error) {
Username: cfg.Username,
Password: cfg.Password,
},
baseURL: fmt.Sprintf("%s/%s/v1/", cfg.Address, contextPath),
baseURL: fmt.Sprintf("%s/%s/v1/", cfg.Address, contextPath),
accessToken: &cString{},
}

if err := client.login(); err != nil {
Expand All @@ -52,7 +53,7 @@ func NewClient(cfg *Config) (*Client, error) {

func (c *Client) login() error {
var resp loginResponse
err := request(
err := c.request(
context.Background(), http.MethodPost, c.baseURL+LoginPath, &resp,
withForm(
"username", c.user.Username,
Expand All @@ -61,13 +62,39 @@ func (c *Client) login() error {
return err
}

c.accessToken = resp.AccessToken
c.accessToken.set(resp.AccessToken)
return nil
}

func (c *Client) request(ctx context.Context, method, url string, result interface{}, opts ...requestOptionFn) error {
_request := func() error {
req, err := newRequest(ctx, method, url, opts...)
if err != nil {
return fmt.Errorf("failed to create new request: %v", err)
}

err = sendRequest(req, result)
if err != nil {
return fmt.Errorf("failed to send request = %v: %v", *req, err)
}

return err
}

err := _request()
if isTokenExpiredError(err) {
if loginErr := c.login(); loginErr != nil {
return fmt.Errorf("token expired %s, re-login attempt failed: err = %w ", err, loginErr)
}
err = _request()
}

return err
}

func (c *Client) GetConfiguration(ctx context.Context, params *ConfigurationId) (*Configuration, error) {
var resp Configuration
err := request(
err := c.request(
ctx, http.MethodGet, c.baseURL+ConfigurationPath, &resp,
withAuthentication(c.accessToken),
withQuery(
Expand All @@ -88,7 +115,7 @@ func (c *Client) GetConfiguration(ctx context.Context, params *ConfigurationId)

func (c *Client) PublishConfiguration(ctx context.Context, params *Configuration) error {
var resp bool
err := request(
err := c.request(
ctx, http.MethodPost, c.baseURL+ConfigurationPath, &resp,
withAuthentication(c.accessToken),
withForm(
Expand All @@ -106,7 +133,7 @@ func (c *Client) PublishConfiguration(ctx context.Context, params *Configuration

func (c *Client) DeleteConfiguration(ctx context.Context, params *ConfigurationId) (bool, error) {
var resp bool
err := request(
err := c.request(
ctx, http.MethodDelete, c.baseURL+ConfigurationPath, &resp,
withAuthentication(c.accessToken),
withQuery(
Expand Down
56 changes: 55 additions & 1 deletion pkg/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@ func defaultLoginHandler(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write(jsonResp)
}

func injectResponseToHandler(statusCode int, resp interface{}, handler http.HandlerFunc) http.HandlerFunc {
injectResponse := true
return func(w http.ResponseWriter, r *http.Request) {
if !injectResponse {
handler(w, r)
return
}

jsonResp, _ := json.Marshal(resp)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
_, _ = w.Write(jsonResp)
injectResponse = false
}
}

func TestNewClient(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -70,7 +86,7 @@ func TestNewClient(t *testing.T) {
})
if tt.expectErr == nil {
assert.Nil(t, err)
assert.Equal(t, _AccessToken, client.accessToken)
assert.Equal(t, _AccessToken, client.accessToken.value())
} else {
assert.NotNil(t, err)
}
Expand Down Expand Up @@ -114,6 +130,44 @@ func TestClient_GetConfiguration(t *testing.T) {
},
expectErr: nil,
},
{
name: "failed and no retry",
getConfigHandler: injectResponseToHandler(
500, map[string]interface{}{
"status": "500",
"error": "Internal",
"message": "server internal error!",
}, func(w http.ResponseWriter, _ *http.Request) {
jsonResp, _ := json.Marshal(map[string]interface{}{
"tenant": "namespace",
"group": "GROUP",
"dataId": "key",
"content": "value",
})
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(jsonResp)
}),
expectErr: nil,
},
{
name: "success after re-login",
getConfigHandler: injectResponseToHandler(
403, map[string]interface{}{
"status": "403",
"error": "Forbidden",
"message": "token expired!",
}, func(w http.ResponseWriter, _ *http.Request) {
jsonResp, _ := json.Marshal(map[string]interface{}{
"tenant": "namespace",
"group": "GROUP",
"dataId": "key",
"content": "value",
})
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(jsonResp)
}),
expectErr: nil,
},
}

for _, tt := range tests {
Expand Down
51 changes: 37 additions & 14 deletions pkg/client/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ import (
"io"
"net/http"
"net/url"
"regexp"
"strings"
"sync"
)

type requestOption struct {
Expand Down Expand Up @@ -43,8 +45,14 @@ func withQuery(kv ...string) requestOptionFn {
}
}

func withAuthentication(token string) requestOptionFn {
return withQuery(accessTokenQueryName, token)
func withAuthentication(token *cString) requestOptionFn {
return func(rOpts *requestOption) error {
if rOpts.query == nil {
rOpts.query = &url.Values{}
}

return updateValues("query string", rOpts.query, accessTokenQueryName, token.value())
}
}

func withForm(kv ...string) requestOptionFn {
Expand Down Expand Up @@ -91,6 +99,19 @@ func newRequest(ctx context.Context, method, url string, opts ...requestOptionFn
return req, nil
}

func isTokenExpiredError(err error) bool {
if err == nil {
return false
}
re := regexp.MustCompile(`request error status_code = (\d*), body = (.*)`)
matches := re.FindAllStringSubmatch(err.Error(), -1)
if matches == nil || len(matches[0]) != 3 {
return false
}
statusCode, body := matches[0][1], matches[0][2]
return statusCode == "403" && strings.Contains(body, `"message":"token expired!"`)
}

func sendRequest(req *http.Request, result interface{}) error {
var err error
resp, err := http.DefaultClient.Do(req)
Expand Down Expand Up @@ -119,18 +140,20 @@ func sendRequest(req *http.Request, result interface{}) error {
return nil
}

func request(ctx context.Context, method, url string, result interface{}, opts ...requestOptionFn) error {
var err error

req, err := newRequest(ctx, method, url, opts...)
if err != nil {
return fmt.Errorf("failed to create new request: %v", err)
}
// cString is a concurrent safe string
type cString struct {
mux sync.RWMutex
v string
}

err = sendRequest(req, result)
if err != nil {
return fmt.Errorf("failed to send request = %v: %v", *req, err)
}
func (ts *cString) value() string {
ts.mux.RLock()
defer ts.mux.RUnlock()
return ts.v
}

return nil
func (ts *cString) set(s string) {
ts.mux.Lock()
defer ts.mux.Unlock()
ts.v = s
}
6 changes: 4 additions & 2 deletions pkg/client/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ package client
import (
"context"
"fmt"
"github.com/stretchr/testify/assert"
"strings"
"testing"

"github.com/stretchr/testify/assert"
)

func TestNewRequest(t *testing.T) {
Expand All @@ -29,7 +30,8 @@ func TestNewRequest(t *testing.T) {

for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
err := request(
cli := new(Client)
err := cli.request(
context.Background(), "GET", "/test", nil,
withForm(tc.form...), withQuery(tc.query...))
assert.True(t, strings.Contains(err.Error(), tc.err.Error()))
Expand Down

0 comments on commit 6d0de25

Please sign in to comment.