diff --git a/gateway/api.go b/gateway/api.go index edcfdea93c8..422ff8d4a60 100644 --- a/gateway/api.go +++ b/gateway/api.go @@ -34,7 +34,6 @@ import ( "fmt" "io/ioutil" "net/http" - "net/url" "os" "path/filepath" "strconv" @@ -51,6 +50,7 @@ import ( "github.com/TykTechnologies/tyk/apidef" "github.com/TykTechnologies/tyk/ctx" "github.com/TykTechnologies/tyk/headers" + "github.com/TykTechnologies/tyk/internal/url" "github.com/TykTechnologies/tyk/storage" "github.com/TykTechnologies/tyk/user" @@ -59,6 +59,14 @@ import ( "github.com/TykTechnologies/tyk/internal/uuid" ) +const ( + oAuthClientTokensKeyPattern = "oauth-data.*oauth-client-tokens.*" +) + +var ( + ErrRequestMalformed = errors.New("request malformed") +) + // apiModifyKeySuccess represents when a Key modification was successful // // swagger:model apiModifyKeySuccess @@ -2044,6 +2052,26 @@ func (gw *Gateway) getOauthClientDetails(keyName, apiID string) (interface{}, in return reportableClientData, http.StatusOK } +func (gw *Gateway) oAuthTokensHandler(w http.ResponseWriter, r *http.Request) { + if !url.QueryHas(r.URL.Query(), "scope") { + doJSONWrite(w, http.StatusUnprocessableEntity, apiError("scope parameter is required")) + return + } + + if r.URL.Query().Get("scope") != "lapsed" { + doJSONWrite(w, http.StatusBadRequest, apiError("unknown scope")) + return + } + + err := gw.purgeLapsedOAuthTokens() + if err != nil { + doJSONWrite(w, http.StatusInternalServerError, apiError("error purging lapsed tokens")) + return + } + + doJSONWrite(w, http.StatusOK, apiOk("lapsed tokens purged")) +} + // Delete Client func (gw *Gateway) handleDeleteOAuthClient(keyName, apiID string) (interface{}, int) { storageID := oauthClientStorageID(keyName) diff --git a/gateway/api_test.go b/gateway/api_test.go index fc508086d13..a6a52bd616d 100644 --- a/gateway/api_test.go +++ b/gateway/api_test.go @@ -2042,3 +2042,71 @@ func TestOrgKeyHandler_LastUpdated(t *testing.T) { }}, }...) } + +func TestPurgeOAuthClientTokens(t *testing.T) { + conf := func(globalConf *config.Config) { + // set tokens to be expired after 1 second + globalConf.OauthTokenExpire = 1 + // cleanup tokens older than 2 seconds + globalConf.OauthTokenExpiredRetainPeriod = 2 + } + + ts := StartTest(conf) + defer ts.Close() + + t.Run("scope validation", func(t *testing.T) { + ts.Run(t, []test.TestCase{ + { + AdminAuth: true, + Path: "/tyk/oauth/tokens/", + Method: http.MethodDelete, + Code: http.StatusUnprocessableEntity, + }, + { + AdminAuth: true, + Path: "/tyk/oauth/tokens/", + QueryParams: map[string]string{"scope": "expired"}, + Method: http.MethodDelete, + Code: http.StatusBadRequest, + }, + }...) + }) + + assertTokensLen := func(t *testing.T, storageManager storage.Handler, storageKey string, expectedTokensLen int) { + nowTs := time.Now().Unix() + startScore := strconv.FormatInt(nowTs, 10) + tokens, _, err := storageManager.GetSortedSetRange(storageKey, startScore, "+inf") + assert.NoError(t, err) + assert.Equal(t, expectedTokensLen, len(tokens)) + } + + t.Run("scope=lapsed", func(t *testing.T) { + spec := ts.LoadTestOAuthSpec() + + clientID1, clientID2 := uuid.New(), uuid.New() + + ts.createOAuthClientIDAndTokens(t, spec, clientID1) + ts.createOAuthClientIDAndTokens(t, spec, clientID2) + storageKey1, storageKey2 := fmt.Sprintf("%s%s", prefixClientTokens, clientID1), + fmt.Sprintf("%s%s", prefixClientTokens, clientID2) + + storageManager := ts.Gw.getGlobalStorageHandler(generateOAuthPrefix(spec.APIID), false) + storageManager.Connect() + + assertTokensLen(t, storageManager, storageKey1, 3) + assertTokensLen(t, storageManager, storageKey2, 3) + + time.Sleep(time.Second * 3) + ts.Run(t, test.TestCase{ + ControlRequest: true, + AdminAuth: true, + Path: "/tyk/oauth/tokens", + QueryParams: map[string]string{"scope": "lapsed"}, + Method: http.MethodDelete, + Code: http.StatusOK, + }) + + assertTokensLen(t, storageManager, storageKey1, 0) + assertTokensLen(t, storageManager, storageKey2, 0) + }) +} diff --git a/gateway/oauth_manager.go b/gateway/oauth_manager.go index 84e66e94c9c..90a870c3246 100644 --- a/gateway/oauth_manager.go +++ b/gateway/oauth_manager.go @@ -10,19 +10,21 @@ import ( "net/http" "net/url" "strings" + "sync" "time" "github.com/TykTechnologies/tyk/request" - "github.com/sirupsen/logrus" - + "github.com/hashicorp/go-multierror" "github.com/lonelycode/osin" + "github.com/sirupsen/logrus" "golang.org/x/crypto/bcrypt" - "github.com/TykTechnologies/tyk/internal/uuid" - "strconv" + "github.com/TykTechnologies/tyk/internal/uuid" + "github.com/TykTechnologies/tyk/headers" + tykerrors "github.com/TykTechnologies/tyk/internal/errors" "github.com/TykTechnologies/tyk/storage" "github.com/TykTechnologies/tyk/user" ) @@ -1186,3 +1188,48 @@ func (r *RedisOsinStorageInterface) SetUser(username string, session *user.Sessi return nil } + +func (gw *Gateway) purgeLapsedOAuthTokens() error { + if gw.GetConfig().OauthTokenExpiredRetainPeriod <= 0 { + return nil + } + + redisCluster := &storage.RedisCluster{KeyPrefix: "", HashKeys: false, RedisController: gw.RedisController} + keys, err := redisCluster.ScanKeys(oAuthClientTokensKeyPattern) + + if err != nil { + log.WithError(err).Debug("error while scanning for tokens") + return err + } + + nowTs := time.Now().Unix() + // clean up expired tokens in sorted set (remove all tokens with score up to current timestamp minus retention) + cleanupStartScore := strconv.FormatInt(nowTs-int64(gw.GetConfig().OauthTokenExpiredRetainPeriod), 10) + + var wg sync.WaitGroup + + errs := make(chan error, len(keys)) + for _, key := range keys { + wg.Add(1) + go func(k string) { + defer wg.Done() + if err := redisCluster.RemoveSortedSetRange(k, "-inf", cleanupStartScore); err != nil { + errs <- err + } + }(key) + } + + // Wait for all goroutines to finish + wg.Wait() + close(errs) + + combinedErr := &multierror.Error{ + ErrorFormat: tykerrors.Formatter, + } + + for err := range errs { + combinedErr = multierror.Append(combinedErr, err) + } + + return combinedErr.ErrorOrNil() +} diff --git a/gateway/oauth_manager_test.go b/gateway/oauth_manager_test.go index 5ef91f54014..69fee07a1a6 100644 --- a/gateway/oauth_manager_test.go +++ b/gateway/oauth_manager_test.go @@ -8,7 +8,9 @@ import ( "bytes" "encoding/json" "net/url" + "path" "reflect" + "strconv" "strings" "testing" @@ -150,6 +152,41 @@ func (ts *Test) createTestOAuthClient(spec *APISpec, clientID string) OAuthClien return testClient } +func (ts *Test) createOAuthClientIDAndTokens(t *testing.T, spec *APISpec, clientID string) { + t.Helper() + ts.createTestOAuthClient(spec, clientID) + + param := make(url.Values) + param.Set("response_type", "token") + param.Set("redirect_uri", authRedirectUri) + param.Set("client_id", clientID) + param.Set("client_secret", authClientSecret) + param.Set("key_rules", keyRules) + + headers := map[string]string{ + "Content-Type": "application/x-www-form-urlencoded", + } + + for i := 0; i < 3; i++ { + resp, err := ts.Run(t, test.TestCase{ + Path: path.Join(spec.Proxy.ListenPath, "/tyk/oauth/authorize-client/"), + Data: param.Encode(), + AdminAuth: true, + Headers: headers, + Method: http.MethodPost, + Code: http.StatusOK, + }) + if err != nil { + t.Error(err) + } + + response := map[string]interface{}{} + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + t.Fatal(err) + } + } +} + func TestOauthMultipleAPIs(t *testing.T) { ts := StartTest(nil) defer ts.Close() @@ -1269,3 +1306,52 @@ func TestJSONToFormValues(t *testing.T) { } }) } + +func TestPurgeOAuthClientTokensEvent(t *testing.T) { + conf := func(globalConf *config.Config) { + // set tokens to be expired after 1 second + globalConf.OauthTokenExpire = 1 + // cleanup tokens older than 2 seconds + globalConf.OauthTokenExpiredRetainPeriod = 2 + } + + ts := StartTest(conf) + defer ts.Close() + + assertTokensLen := func(t *testing.T, storageManager storage.Handler, storageKey string, expectedTokensLen int) { + nowTs := time.Now().Unix() + startScore := strconv.FormatInt(nowTs, 10) + tokens, _, err := storageManager.GetSortedSetRange(storageKey, startScore, "+inf") + assert.NoError(t, err) + assert.Equal(t, expectedTokensLen, len(tokens)) + } + + spec := ts.LoadTestOAuthSpec() + + clientID1, clientID2 := uuid.New(), uuid.New() + + ts.createOAuthClientIDAndTokens(t, spec, clientID1) + ts.createOAuthClientIDAndTokens(t, spec, clientID2) + storageKey1, storageKey2 := fmt.Sprintf("%s%s", prefixClientTokens, clientID1), + fmt.Sprintf("%s%s", prefixClientTokens, clientID2) + + storageManager := ts.Gw.getGlobalStorageHandler(generateOAuthPrefix(spec.APIID), false) + storageManager.Connect() + + assertTokensLen(t, storageManager, storageKey1, 3) + assertTokensLen(t, storageManager, storageKey2, 3) + + time.Sleep(time.Second * 3) + + // emit event + + n := Notification{ + Command: OAuthPurgeLapsedTokens, + Gw: ts.Gw, + } + ts.Gw.MainNotifier.Notify(n) + + assertTokensLen(t, storageManager, storageKey1, 0) + assertTokensLen(t, storageManager, storageKey2, 0) + +} diff --git a/gateway/redis_signals.go b/gateway/redis_signals.go index e790d9b5244..e14d569e622 100644 --- a/gateway/redis_signals.go +++ b/gateway/redis_signals.go @@ -34,6 +34,7 @@ const ( NoticeGatewayDRLNotification NotificationCommand = "NoticeGatewayDRLNotification" NoticeGatewayLENotification NotificationCommand = "NoticeGatewayLENotification" KeySpaceUpdateNotification NotificationCommand = "KeySpaceUpdateNotification" + OAuthPurgeLapsedTokens NotificationCommand = "OAuthPurgeLapsedTokens" ) // Notification is a type that encodes a message published to a pub sub channel (shared between implementations) @@ -119,6 +120,10 @@ func (gw *Gateway) handleRedisEvent(v interface{}, handled func(NotificationComm gw.reloadURLStructure(reloaded) case KeySpaceUpdateNotification: gw.handleKeySpaceEventCacheFlush(notif.Payload) + case OAuthPurgeLapsedTokens: + if err := gw.purgeLapsedOAuthTokens(); err != nil { + log.WithError(err).Errorf("error while purging tokens for event %s", OAuthPurgeLapsedTokens) + } default: pubSubLog.Warnf("Unknown notification command: %q", notif.Command) return diff --git a/gateway/server.go b/gateway/server.go index 6a70607623f..72510cd04ba 100644 --- a/gateway/server.go +++ b/gateway/server.go @@ -625,6 +625,7 @@ func (gw *Gateway) loadControlAPIEndpoints(muxer *mux.Router) { r.HandleFunc("/oauth/clients/{apiID}", gw.oAuthClientHandler).Methods("GET", "DELETE") r.HandleFunc("/oauth/clients/{apiID}/{keyName:[^/]*}", gw.oAuthClientHandler).Methods("GET", "DELETE") r.HandleFunc("/oauth/clients/{apiID}/{keyName}/tokens", gw.oAuthClientTokensHandler).Methods("GET") + r.HandleFunc("/oauth/tokens", gw.oAuthTokensHandler).Methods(http.MethodDelete) mainLog.Debug("Loaded API Endpoints") } diff --git a/go.mod b/go.mod index 260ad920821..c57865f6b7c 100644 --- a/go.mod +++ b/go.mod @@ -47,7 +47,9 @@ require ( github.com/hashicorp/go-hclog v0.14.1 // indirect github.com/hashicorp/go-immutable-radix v1.3.0 // indirect github.com/hashicorp/go-msgpack v0.5.5 // indirect + github.com/hashicorp/go-multierror v1.1.0 github.com/hashicorp/go-retryablehttp v0.6.7 // indirect + github.com/hashicorp/go-version v1.1.0 github.com/hashicorp/memberlist v0.1.6 // indirect github.com/hashicorp/serf v0.8.6 // indirect github.com/hashicorp/vault/api v1.0.5-0.20200717191844-f687267c8086 @@ -108,3 +110,6 @@ require ( ) //replace github.com/jensneuse/graphql-go-tools => ../graphql-go-tools +replace sourcegraph.com/sourcegraph/appdash => github.com/sourcegraph/appdash v0.0.0-20211028080628-e2786a622600 + +replace sourcegraph.com/sourcegraph/appdash-data => github.com/sourcegraph/appdash-data v0.0.0-20151005221446-73f23eafcf67 diff --git a/go.sum b/go.sum index 0a5099f2c13..0630727129d 100644 --- a/go.sum +++ b/go.sum @@ -271,6 +271,7 @@ github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/b github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE= github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/hashicorp/go-version v1.1.0 h1:bPIoEKD27tNdebFGGxxYwcL4nepeY4j1QP23PFRGzg0= github.com/hashicorp/go-version v1.1.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/hashicorp/go.net v0.0.1/go.mod h1:hjKkEWcCURg++eb33jQU7oqQcI9XDCnUzHA0oac0k90= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= @@ -519,6 +520,8 @@ github.com/shurcooL/vfsgen v0.0.0-20180121065927-ffb13db8def0/go.mod h1:TrYk7fJV github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/sourcegraph/appdash v0.0.0-20211028080628-e2786a622600/go.mod h1:V952P4GGl1v/MMynLwxVdWEbSZJx+n0oOO3ljnez+WU= +github.com/sourcegraph/appdash-data v0.0.0-20151005221446-73f23eafcf67/go.mod h1:tNZjgbYncKL5HxvDULAr/mWDmFz4B7H8yrXEDlnoIiw= github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= github.com/spf13/afero v1.6.0 h1:xoax2sJ2DT8S8xA2paPFjDCScCNeWsg75VG0DLRreiY= github.com/spf13/afero v1.6.0/go.mod h1:Ai8FlHk4v/PARR026UzYexafAt9roJ7LcLMAmO6Z93I= diff --git a/internal/errors/errors.go b/internal/errors/errors.go new file mode 100644 index 00000000000..4faa273d4ae --- /dev/null +++ b/internal/errors/errors.go @@ -0,0 +1,15 @@ +package errors + +import "strings" + +func Formatter(errs []error) string { + var result strings.Builder + for i, err := range errs { + result.WriteString(err.Error()) + if i < len(errs)-1 { + result.WriteString("\n") + } + } + + return result.String() +} diff --git a/internal/errors/errors_test.go b/internal/errors/errors_test.go new file mode 100644 index 00000000000..e4d8f9ff56d --- /dev/null +++ b/internal/errors/errors_test.go @@ -0,0 +1,39 @@ +package errors + +import ( + "errors" + "testing" +) + +func TestErrorFormatter(t *testing.T) { + tests := []struct { + name string + errs []error + expected string + }{ + { + name: "No errors", + errs: []error{}, + expected: "", + }, + { + name: "Single error", + errs: []error{errors.New("error 1")}, + expected: "error 1", + }, + { + name: "Multiple errors", + errs: []error{errors.New("error 1"), errors.New("error 2")}, + expected: "error 1\nerror 2", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := Formatter(tc.errs) + if result != tc.expected { + t.Errorf("Formatter() = %v, want %v", result, tc.expected) + } + }) + } +} diff --git a/internal/url/url.go b/internal/url/url.go new file mode 100644 index 00000000000..8eaa6d050e1 --- /dev/null +++ b/internal/url/url.go @@ -0,0 +1,11 @@ +package url + +import "net/url" + +type URL = url.URL + +// QueryHas checks whether a given key is set. +func QueryHas(v url.Values, key string) bool { + _, ok := v[key] + return ok +} diff --git a/internal/url/url_test.go b/internal/url/url_test.go new file mode 100644 index 00000000000..69490e5f062 --- /dev/null +++ b/internal/url/url_test.go @@ -0,0 +1,55 @@ +package url + +import ( + "net/url" + "testing" +) + +func TestQueryHas(t *testing.T) { + tests := []struct { + name string + values url.Values + key string + expected bool + }{ + { + name: "Key present", + values: url.Values{"test": []string{"value"}}, + key: "test", + expected: true, + }, + { + name: "Key absent", + values: url.Values{"test": []string{"value"}}, + key: "missing", + expected: false, + }, + { + name: "Empty values", + values: url.Values{}, + key: "any", + expected: false, + }, + { + name: "Multiple keys, target present", + values: url.Values{"test": []string{"value"}, "another": []string{"value2"}}, + key: "another", + expected: true, + }, + { + name: "Multiple keys, target absent", + values: url.Values{"test": []string{"value"}, "another": []string{"value2"}}, + key: "missing", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := QueryHas(tt.values, tt.key) + if result != tt.expected { + t.Errorf("QueryHas(%v, %q) = %v; want %v", tt.values, tt.key, result, tt.expected) + } + }) + } +} diff --git a/storage/redis_cluster.go b/storage/redis_cluster.go index 26f8791a42c..e37c70599fa 100644 --- a/storage/redis_cluster.go +++ b/storage/redis_cluster.go @@ -430,64 +430,15 @@ func (r *RedisCluster) IncrememntWithExpire(keyName string, expire int64) int64 // GetKeys will return all keys according to the filter (filter is a prefix - e.g. tyk.keys.*) func (r *RedisCluster) GetKeys(filter string) []string { - if err := r.up(); err != nil { - log.Debug(err) - return nil - } - - singleton, err := r.singleton() - if err != nil { - log.Error(err) - return nil - } - filterHash := "" if filter != "" { filterHash = r.hashKey(filter) } + searchStr := r.KeyPrefix + filterHash + "*" log.Debug("[STORE] Getting list by: ", searchStr) - fnFetchKeys := func(client *redis.Client) ([]string, error) { - values := make([]string, 0) - - iter := client.Scan(r.RedisController.ctx, 0, searchStr, 0).Iterator() - for iter.Next(r.RedisController.ctx) { - values = append(values, iter.Val()) - } - - if err := iter.Err(); err != nil { - return nil, err - } - - return values, nil - } - - sessions := make([]string, 0) - - switch v := singleton.(type) { - case *redis.ClusterClient: - ch := make(chan []string) - - go func() { - err = v.ForEachMaster(r.RedisController.ctx, func(ctx context.Context, client *redis.Client) error { - values, err := fnFetchKeys(client) - if err != nil { - return err - } - - ch <- values - return nil - }) - close(ch) - }() - - for res := range ch { - sessions = append(sessions, res...) - } - case *redis.Client: - sessions, err = fnFetchKeys(v) - } + sessions, err := r.ScanKeys(searchStr) if err != nil { log.Error("Error while fetching keys:", err) @@ -1148,7 +1099,7 @@ func (r *RedisCluster) SetRollingWindow(keyName string, per int64, value_overrid return intVal, result } -func (r RedisCluster) GetRollingWindow(keyName string, per int64, pipeline bool) (int, []interface{}) { +func (r *RedisCluster) GetRollingWindow(keyName string, per int64, pipeline bool) (int, []interface{}) { if err := r.up(); err != nil { log.Debug(err) return 0, nil @@ -1298,3 +1249,67 @@ func (r *RedisCluster) RemoveSortedSetRange(keyName, scoreFrom, scoreTo string) func (r *RedisCluster) ControllerInitiated() bool { return r.RedisController != nil } + +// ScanKeys will return all keys according to the pattern. +func (r *RedisCluster) ScanKeys(pattern string) ([]string, error) { + if err := r.up(); err != nil { + log.Debug(err) + return nil, err + } + + singleton, err := r.singleton() + if err != nil { + log.Error(err) + return nil, err + } + + log.Debug("[STORE] scanning keys by: ", pattern) + + fnFetchKeys := func(client *redis.Client) ([]string, error) { + values := make([]string, 0) + + iter := client.Scan(r.RedisController.ctx, 0, pattern, 0).Iterator() + for iter.Next(r.RedisController.ctx) { + values = append(values, iter.Val()) + } + + if err := iter.Err(); err != nil { + return nil, err + } + + return values, nil + } + + keys := make([]string, 0) + + switch v := singleton.(type) { + case *redis.ClusterClient: + ch := make(chan []string) + + go func() { + err = v.ForEachMaster(r.RedisController.ctx, func(ctx context.Context, client *redis.Client) error { + values, err := fnFetchKeys(client) + if err != nil { + return err + } + + ch <- values + return nil + }) + close(ch) + }() + + for res := range ch { + keys = append(keys, res...) + } + case *redis.Client: + keys, err = fnFetchKeys(v) + } + + if err != nil { + log.Error("Error while scanning for keys:", err) + return nil, err + } + + return keys, nil +}