From 1c095b310826e63788d892ed8829ef86cfcf3c1a Mon Sep 17 00:00:00 2001 From: Tit Petric Date: Thu, 5 Sep 2024 20:49:19 +0200 Subject: [PATCH] Move rate limit info test --- gateway/api_definition_test.go | 2 +- gateway/session_manager.go | 4 +- gateway/session_manager_test.go | 155 ++++++++++++++++++++++++++++++++ user/session_test.go | 147 ------------------------------ 4 files changed, 158 insertions(+), 150 deletions(-) diff --git a/gateway/api_definition_test.go b/gateway/api_definition_test.go index 7f62b3dd886..7817b641cc3 100644 --- a/gateway/api_definition_test.go +++ b/gateway/api_definition_test.go @@ -418,7 +418,7 @@ func TestConflictingPaths(t *testing.T) { } func TestIgnored(t *testing.T) { - ts := StartTest(func (c *config.Config) { + ts := StartTest(func(c *config.Config) { c.HttpServerOptions.EnablePrefixMatching = true }) defer ts.Close() diff --git a/gateway/session_manager.go b/gateway/session_manager.go index 4067cd5f7ca..bce1f784741 100644 --- a/gateway/session_manager.go +++ b/gateway/session_manager.go @@ -228,7 +228,7 @@ func (sfr sessionFailReason) String() string { } } -func (l *SessionLimiter) rateLimitInfo(r *http.Request, api *APISpec, endpoints user.Endpoints) (*user.EndpointRateLimitInfo, bool) { +func (l *SessionLimiter) RateLimitInfo(r *http.Request, api *APISpec, endpoints user.Endpoints) (*user.EndpointRateLimitInfo, bool) { // Hook per-api settings here (m.Spec...) isPrefixMatch := l.config.HttpServerOptions.EnablePrefixMatching isSuffixMatch := l.config.HttpServerOptions.EnableSuffixMatching @@ -291,7 +291,7 @@ func (l *SessionLimiter) ForwardMessage(r *http.Request, session *user.SessionSt endpointRLKeySuffix = "" ) - endpointRLInfo, doEndpointRL := l.rateLimitInfo(r, api, accessDef.Endpoints) + endpointRLInfo, doEndpointRL := l.RateLimitInfo(r, api, accessDef.Endpoints) if doEndpointRL { apiLimit.Rate = endpointRLInfo.Rate apiLimit.Per = endpointRLInfo.Per diff --git a/gateway/session_manager_test.go b/gateway/session_manager_test.go index 847ef644b5d..9f54069ec77 100644 --- a/gateway/session_manager_test.go +++ b/gateway/session_manager_test.go @@ -1,11 +1,15 @@ package gateway import ( + "net/http" + "net/http/httptest" "testing" "github.com/stretchr/testify/assert" "github.com/TykTechnologies/tyk/apidef" + "github.com/TykTechnologies/tyk/config" + "github.com/TykTechnologies/tyk/storage" "github.com/TykTechnologies/tyk/user" ) @@ -140,3 +144,154 @@ func TestGetAccessDefinitionByAPIIDOrSession(t *testing.T) { assert.NoError(t, err) }) } + +func TestSessionLimiter_RateLimitInfo(t *testing.T) { + limiter := &SessionLimiter{config: &config.Default} + spec := BuildAPI(func(a *APISpec) { + a.Proxy.ListenPath = "/" + })[0] + + tests := []struct { + name string + method string + path string + endpoints user.Endpoints + expected *user.EndpointRateLimitInfo + found bool + }{ + { + name: "Matching endpoint and method", + method: http.MethodGet, + path: "/api/v1/users", + endpoints: user.Endpoints{ + { + Path: "/api/v1/users", + Methods: []user.EndpointMethod{ + {Name: "GET", Limit: user.RateLimit{Rate: 100, Per: 60}}, + }, + }, + }, + expected: &user.EndpointRateLimitInfo{ + KeySuffix: storage.HashStr("GET:/api/v1/users"), + Rate: 100, + Per: 60, + }, + found: true, + }, + { + name: "Matching endpoint, non-matching method", + path: "/api/v1/users", + method: http.MethodPost, + endpoints: []user.Endpoint{ + { + Path: "/api/v1/users", + Methods: []user.EndpointMethod{ + {Name: "GET", Limit: user.RateLimit{Rate: 100, Per: 60}}, + }, + }, + }, + expected: nil, + found: false, + }, + { + name: "Non-matching endpoint", + method: http.MethodGet, + path: "/api/v1/products", + endpoints: []user.Endpoint{ + { + Path: "/api/v1/users", + Methods: []user.EndpointMethod{ + {Name: "GET", Limit: user.RateLimit{Rate: 100, Per: 60}}, + }, + }, + }, + expected: nil, + found: false, + }, + { + name: "Regex path matching", + path: "/api/v1/users/123", + method: http.MethodGet, + endpoints: []user.Endpoint{ + { + Path: "/api/v1/users/[0-9]+", + Methods: []user.EndpointMethod{ + {Name: "GET", Limit: user.RateLimit{Rate: 50, Per: 30}}, + }, + }, + }, + expected: &user.EndpointRateLimitInfo{ + KeySuffix: storage.HashStr("GET:/api/v1/users/[0-9]+"), + Rate: 50, + Per: 30, + }, + found: true, + }, + { + name: "Invalid regex path", + path: "/api/v1/users", + method: http.MethodGet, + endpoints: []user.Endpoint{ + { + Path: "[invalid regex", + Methods: []user.EndpointMethod{ + {Name: "GET", Limit: user.RateLimit{Rate: 100, Per: 60}}, + }, + }, + }, + expected: nil, + found: false, + }, + { + name: "Invalid regex path and valid url", + path: "/api/v1/users", + method: http.MethodGet, + endpoints: []user.Endpoint{ + { + Path: "[invalid regex", + Methods: []user.EndpointMethod{ + {Name: "GET", Limit: user.RateLimit{Rate: 100, Per: 60}}, + }, + }, + { + Path: "/api/v1/users", + Methods: []user.EndpointMethod{ + {Name: "GET", Limit: user.RateLimit{Rate: 100, Per: 60}}, + }, + }, + }, + expected: &user.EndpointRateLimitInfo{ + KeySuffix: storage.HashStr("GET:/api/v1/users"), + Rate: 100, + Per: 60, + }, + found: true, + }, + { + name: "nil endpoints", + path: "/api/v1/users", + method: http.MethodGet, + endpoints: nil, + expected: nil, + found: false, + }, + { + name: "empty endpoints", + path: "/api/v1/users", + method: http.MethodGet, + endpoints: user.Endpoints{}, + expected: nil, + found: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(tt.method, tt.path, nil) + + result, found := limiter.RateLimitInfo(req, spec, tt.endpoints) + assert.Equal(t, tt.found, found) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/user/session_test.go b/user/session_test.go index 74914ea46aa..9195e7f1112 100644 --- a/user/session_test.go +++ b/user/session_test.go @@ -2,13 +2,10 @@ package user import ( "encoding/json" - "net/http" "reflect" "testing" "time" - "github.com/TykTechnologies/tyk/storage" - "github.com/TykTechnologies/tyk/apidef" "github.com/stretchr/testify/assert" @@ -377,150 +374,6 @@ func TestAPILimit_Clone(t *testing.T) { } } -func TestEndpoints_RateLimitInfo(t *testing.T) { - tests := []struct { - name string - method string - path string - endpoints Endpoints - expected *EndpointRateLimitInfo - found bool - }{ - { - name: "Matching endpoint and method", - method: http.MethodGet, - path: "/api/v1/users", - endpoints: Endpoints{ - { - Path: "/api/v1/users", - Methods: []EndpointMethod{ - {Name: "GET", Limit: RateLimit{Rate: 100, Per: 60}}, - }, - }, - }, - expected: &EndpointRateLimitInfo{ - KeySuffix: storage.HashStr("GET:/api/v1/users"), - Rate: 100, - Per: 60, - }, - found: true, - }, - { - name: "Matching endpoint, non-matching method", - path: "/api/v1/users", - method: http.MethodPost, - endpoints: []Endpoint{ - { - Path: "/api/v1/users", - Methods: []EndpointMethod{ - {Name: "GET", Limit: RateLimit{Rate: 100, Per: 60}}, - }, - }, - }, - expected: nil, - found: false, - }, - { - name: "Non-matching endpoint", - method: http.MethodGet, - path: "/api/v1/products", - endpoints: []Endpoint{ - { - Path: "/api/v1/users", - Methods: []EndpointMethod{ - {Name: "GET", Limit: RateLimit{Rate: 100, Per: 60}}, - }, - }, - }, - expected: nil, - found: false, - }, - { - name: "Regex path matching", - path: "/api/v1/users/123", - method: http.MethodGet, - endpoints: []Endpoint{ - { - Path: "/api/v1/users/[0-9]+", - Methods: []EndpointMethod{ - {Name: "GET", Limit: RateLimit{Rate: 50, Per: 30}}, - }, - }, - }, - expected: &EndpointRateLimitInfo{ - KeySuffix: storage.HashStr("GET:/api/v1/users/[0-9]+"), - Rate: 50, - Per: 30, - }, - found: true, - }, - { - name: "Invalid regex path", - path: "/api/v1/users", - method: http.MethodGet, - endpoints: []Endpoint{ - { - Path: "[invalid regex", - Methods: []EndpointMethod{ - {Name: "GET", Limit: RateLimit{Rate: 100, Per: 60}}, - }, - }, - }, - expected: nil, - found: false, - }, - { - name: "Invalid regex path and valid url", - path: "/api/v1/users", - method: http.MethodGet, - endpoints: []Endpoint{ - { - Path: "[invalid regex", - Methods: []EndpointMethod{ - {Name: "GET", Limit: RateLimit{Rate: 100, Per: 60}}, - }, - }, - { - Path: "/api/v1/users", - Methods: []EndpointMethod{ - {Name: "GET", Limit: RateLimit{Rate: 100, Per: 60}}, - }, - }, - }, - expected: &EndpointRateLimitInfo{ - KeySuffix: storage.HashStr("GET:/api/v1/users"), - Rate: 100, - Per: 60, - }, - found: true, - }, - { - name: "nil endpoints", - path: "/api/v1/users", - method: http.MethodGet, - endpoints: nil, - expected: nil, - found: false, - }, - { - name: "empty endpoints", - path: "/api/v1/users", - method: http.MethodGet, - endpoints: Endpoints{}, - expected: nil, - found: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, found := tt.endpoints.RateLimitInfo(tt.method, []string{tt.path}) - assert.Equal(t, tt.found, found) - assert.Equal(t, tt.expected, result) - }) - } -} - func TestEndpoints_Map(t *testing.T) { tests := []struct { name string