diff --git a/internal/policy/Taskfile.yml b/internal/policy/Taskfile.yml index 8f5cf09cff2..97c1c82735d 100644 --- a/internal/policy/Taskfile.yml +++ b/internal/policy/Taskfile.yml @@ -19,8 +19,24 @@ tasks: - defer: { task: services:down } - goimports -w . - go fmt ./... + - task: test + vars: + run: '{{.run}}' + + test: + desc: "Run tests" + requires: + vars: [run] + cmds: - go test -count=1 -run='({{.run}})' -cover -coverprofile=pkg.cov -v . + stress: + desc: "Run stress tests" + requires: + vars: [run] + cmds: + - go test -count=2000 -run='({{.run}})' -cover -coverprofile=pkg.cov . + cover: desc: "Show source coverage" aliases: [coverage, cov] diff --git a/internal/policy/apply.go b/internal/policy/apply.go index 6baafaecf2c..8a24ce7ab3a 100644 --- a/internal/policy/apply.go +++ b/internal/policy/apply.go @@ -22,6 +22,7 @@ type Repository interface { PolicyByID(string) (user.Policy, bool) } +// Service represents the policy service for gateway. type Service struct { storage Repository logger *logrus.Logger @@ -30,6 +31,7 @@ type Service struct { orgID *string } +// New creates a new policy.Service object. func New(orgID *string, storage Repository, logger *logrus.Logger) *Service { return &Service{ orgID: orgID, @@ -107,7 +109,8 @@ func (t *Service) Apply(session *user.SessionState) error { ) storage := t.storage - customPolicies, err := session.CustomPolicies() + + customPolicies, err := session.GetCustomPolicies() if err != nil { policyIDs = session.PolicyIDs() } else { @@ -242,8 +245,9 @@ func (t *Service) Apply(session *user.SessionState) error { return nil } -func (t *Service) Logger() *logrus.Logger { - return t.logger +// Logger implements a typical logger signature with service context. +func (t *Service) Logger() *logrus.Entry { + return logrus.NewEntry(t.logger) } // ApplyRateLimits will write policy limits to session and apiLimits. @@ -355,7 +359,7 @@ func (t *Service) applyPartitions(policy user.Policy, session *user.SessionState if !usePartitions || policy.Partitions.Acl { applyState.didAcl[k] = true - ar.AllowedURLs = copyAllowedURLs(v.AllowedURLs) + ar.AllowedURLs = MergeAllowedURLs(ar.AllowedURLs, v.AllowedURLs) // Merge ACLs for the same API if r, ok := rights[k]; ok { @@ -365,19 +369,7 @@ func (t *Service) applyPartitions(policy user.Policy, session *user.SessionState } r.Versions = appendIfMissing(rights[k].Versions, v.Versions...) - for _, u := range v.AllowedURLs { - found := false - for ai, au := range r.AllowedURLs { - if u.URL == au.URL { - found = true - r.AllowedURLs[ai].Methods = appendIfMissing(au.Methods, u.Methods...) - } - } - - if !found { - r.AllowedURLs = append(r.AllowedURLs, v.AllowedURLs...) - } - } + r.AllowedURLs = MergeAllowedURLs(r.AllowedURLs, v.AllowedURLs) for _, t := range v.RestrictedTypes { for ri, rt := range r.RestrictedTypes { diff --git a/internal/policy/apply_test.go b/internal/policy/apply_test.go index a0fede15c61..37b77e4a4a5 100644 --- a/internal/policy/apply_test.go +++ b/internal/policy/apply_test.go @@ -9,11 +9,9 @@ import ( "testing" "github.com/sirupsen/logrus" - - "github.com/TykTechnologies/graphql-go-tools/pkg/graphql" - "github.com/stretchr/testify/assert" + "github.com/TykTechnologies/graphql-go-tools/pkg/graphql" "github.com/TykTechnologies/tyk/internal/policy" "github.com/TykTechnologies/tyk/user" ) @@ -22,9 +20,9 @@ import ( var testDataFS embed.FS func TestApplyRateLimits_PolicyLimits(t *testing.T) { - svc := &policy.Service{} - t.Run("policy limits unset", func(t *testing.T) { + svc := &policy.Service{} + session := &user.SessionState{ Rate: 5, Per: 10, @@ -44,6 +42,8 @@ func TestApplyRateLimits_PolicyLimits(t *testing.T) { }) t.Run("policy limits apply all", func(t *testing.T) { + svc := &policy.Service{} + session := &user.SessionState{ Rate: 5, Per: 10, @@ -69,6 +69,8 @@ func TestApplyRateLimits_PolicyLimits(t *testing.T) { // changes are applied to api limits, but skipped on // the session as the session has a higher allowance. t.Run("policy limits apply per-api", func(t *testing.T) { + svc := &policy.Service{} + session := &user.SessionState{ Rate: 15, Per: 10, @@ -93,6 +95,8 @@ func TestApplyRateLimits_PolicyLimits(t *testing.T) { // As the policy defined a lower rate than apiLimits, // no changes to api limits are applied. t.Run("policy limits skip", func(t *testing.T) { + svc := &policy.Service{} + session := &user.SessionState{ Rate: 5, Per: 10, @@ -117,29 +121,26 @@ func TestApplyRateLimits_PolicyLimits(t *testing.T) { func TestApplyRateLimits_FromCustomPolicies(t *testing.T) { svc := &policy.Service{} - t.Run("Custom policies", func(t *testing.T) { - session := &user.SessionState{} - session.SetCustomPolicies([]user.Policy{ - { - ID: "pol1", - Partitions: user.PolicyPartitions{RateLimit: true}, - Rate: 8, - Per: 1, - AccessRights: map[string]user.AccessDefinition{"a": {}}, - }, - { - ID: "pol2", - Partitions: user.PolicyPartitions{RateLimit: true}, - Rate: 10, - Per: 1, - AccessRights: map[string]user.AccessDefinition{"a": {}}, - }, - }) - - svc.Apply(session) - - assert.Equal(t, 10, int(session.Rate)) + session := &user.SessionState{} + session.SetCustomPolicies([]user.Policy{ + { + ID: "pol1", + Partitions: user.PolicyPartitions{RateLimit: true}, + Rate: 8, + Per: 1, + AccessRights: map[string]user.AccessDefinition{"a": {}}, + }, + { + ID: "pol2", + Partitions: user.PolicyPartitions{RateLimit: true}, + Rate: 10, + Per: 1, + AccessRights: map[string]user.AccessDefinition{"a": {}}, + }, }) + + assert.NoError(t, svc.Apply(session)) + assert.Equal(t, 10, int(session.Rate)) } func TestApplyEndpointLevelLimits(t *testing.T) { @@ -190,7 +191,7 @@ func testPrepareApplyPolicies(tb testing.TB) (*policy.Service, []testApplyPolici err = json.Unmarshal(f, &repoPols) assert.NoError(tb, err) - store := policy.NewStore(repoPols) + store := policy.NewStoreMap(repoPols) orgID := "" service := policy.New(&orgID, store, logrus.StandardLogger()) diff --git a/internal/policy/store.go b/internal/policy/store.go index 909c53e8bca..7829659db89 100644 --- a/internal/policy/store.go +++ b/internal/policy/store.go @@ -4,20 +4,27 @@ import ( "github.com/TykTechnologies/tyk/user" ) -// Store is an in-memory policy storage object that -// implements the repository for policy access. We -// do not implement concurrency protections here. +// Store is an in-memory policy storage object that implements the +// repository for policy access. We do not implement concurrency +// protections here. Where order is important, use this. type Store struct { - policies map[string]user.Policy + policies []user.Policy } -func NewStore(policies map[string]user.Policy) *Store { +// NewStore returns a new policy.Store. +func NewStore(policies []user.Policy) *Store { return &Store{ policies: policies, } } +// PolicyIDs returns a list policy IDs in the store. +// It will return nil if no policies exist. func (s *Store) PolicyIDs() []string { + if len(s.policies) == 0 { + return nil + } + policyIDs := make([]string, 0, len(s.policies)) for _, val := range s.policies { policyIDs = append(policyIDs, val.ID) @@ -25,11 +32,17 @@ func (s *Store) PolicyIDs() []string { return policyIDs } +// PolicyByID returns a policy by ID. func (s *Store) PolicyByID(id string) (user.Policy, bool) { - v, ok := s.policies[id] - return v, ok + for _, pol := range s.policies { + if pol.ID == id { + return pol, true + } + } + return user.Policy{}, false } +// PolicyCount returns the number of policies in the store. func (s *Store) PolicyCount() int { return len(s.policies) } diff --git a/internal/policy/store_map.go b/internal/policy/store_map.go new file mode 100644 index 00000000000..a035c320a4a --- /dev/null +++ b/internal/policy/store_map.go @@ -0,0 +1,46 @@ +package policy + +import ( + "github.com/TykTechnologies/tyk/user" +) + +// StoreMap is same as Store, but doesn't preserve order. +type StoreMap struct { + policies map[string]user.Policy +} + +// NewStoreMap returns a new policy.StoreMap. +func NewStoreMap(policies map[string]user.Policy) *StoreMap { + if len(policies) == 0 { + policies = make(map[string]user.Policy) + } + + return &StoreMap{ + policies: policies, + } +} + +// PolicyIDs returns a list policy IDs in the store. +// It will return nil if no policies exist. +func (s *StoreMap) PolicyIDs() []string { + if len(s.policies) == 0 { + return nil + } + + policyIDs := make([]string, 0, len(s.policies)) + for _, val := range s.policies { + policyIDs = append(policyIDs, val.ID) + } + return policyIDs +} + +// PolicyByID returns a policy by ID. +func (s *StoreMap) PolicyByID(id string) (user.Policy, bool) { + v, ok := s.policies[id] + return v, ok +} + +// PolicyCount returns the number of policies in the store. +func (s *StoreMap) PolicyCount() int { + return len(s.policies) +} diff --git a/internal/policy/util.go b/internal/policy/util.go index ed34211c0f4..8558fed0800 100644 --- a/internal/policy/util.go +++ b/internal/policy/util.go @@ -1,42 +1,62 @@ package policy import ( + "slices" + "github.com/TykTechnologies/tyk/user" ) -// appendIfMissing ensures dest slice is unique with new items. -func appendIfMissing(src []string, in ...string) []string { - // Use map for uniqueness - srcMap := map[string]bool{} - for _, v := range src { - srcMap[v] = true - } - for _, v := range in { - srcMap[v] = true - } - - // Produce unique []string, maintain sort order - uniqueSorted := func(src []string, keys map[string]bool) []string { - result := make([]string, 0, len(keys)) - for _, v := range src { - // append missing value - if val := keys[v]; val { - result = append(result, v) - delete(keys, v) +// MergeAllowedURLs will merge s1 and s2 to produce a merged result. +// It maintains order of keys in s1 and s2 as they are seen. +// If the result is an empty set, nil is returned. +func MergeAllowedURLs(s1, s2 []user.AccessSpec) []user.AccessSpec { + order := []string{} + merged := map[string][]string{} + + // Loop input sets and merge through a map. + for _, src := range [][]user.AccessSpec{s1, s2} { + for _, r := range src { + url := r.URL + v, ok := merged[url] + if !ok { + // First time we see the spec + merged[url] = r.Methods + + // Maintain order + order = append(order, url) + + continue } + merged[url] = appendIfMissing(v, r.Methods...) } - return result } - // no new items from `in` - if len(srcMap) == len(src) { - return src + // Early exit without allocating. + if len(order) == 0 { + return nil } - src = uniqueSorted(src, srcMap) - in = uniqueSorted(in, srcMap) + // Provide results in desired order. + result := make([]user.AccessSpec, 0, len(order)) + for _, key := range order { + spec := user.AccessSpec{ + Methods: merged[key], + URL: key, + } + result = append(result, spec) + } + return result +} - return append(src, in...) +// appendIfMissing ensures dest slice is unique with new items. +func appendIfMissing(dest []string, in ...string) []string { + for _, v := range in { + if slices.Contains(dest, v) { + continue + } + dest = append(dest, v) + } + return dest } // intersection gets intersection of the given two slices. @@ -56,30 +76,6 @@ func intersection(a []string, b []string) (inter []string) { return } -// contains checks whether the given slice contains the given item. -func contains(s []string, i string) bool { - for _, a := range s { - if a == i { - return true - } - } - return false -} - -// greaterThanFloat64 checks whether first float64 value is bigger than second float64 value. -// -1 means infinite and the biggest value. -func greaterThanFloat64(first, second float64) bool { - if first == -1 { - return true - } - - if second == -1 { - return false - } - - return first > second -} - // greaterThanInt64 checks whether first int64 value is bigger than second int64 value. // -1 means infinite and the biggest value. func greaterThanInt64(first, second int64) bool { @@ -107,23 +103,3 @@ func greaterThanInt(first, second int) bool { return first > second } - -func copyAllowedURLs(input []user.AccessSpec) []user.AccessSpec { - if input == nil { - return nil - } - - copied := make([]user.AccessSpec, len(input)) - - for i, as := range input { - copied[i] = user.AccessSpec{ - URL: as.URL, - } - if as.Methods != nil { - copied[i].Methods = make([]string, len(as.Methods)) - copy(copied[i].Methods, as.Methods) - } - } - - return copied -} diff --git a/internal/policy/util_test.go b/internal/policy/util_test.go new file mode 100644 index 00000000000..460d0cfb119 --- /dev/null +++ b/internal/policy/util_test.go @@ -0,0 +1,64 @@ +package policy_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/TykTechnologies/tyk/internal/policy" + "github.com/TykTechnologies/tyk/user" +) + +func TestMergeAllowedURLs(t *testing.T) { + svc := &policy.Service{} + + session := &user.SessionState{} + policies := []user.Policy{ + { + ID: "pol1", + AccessRights: map[string]user.AccessDefinition{ + "a": { + AllowedURLs: []user.AccessSpec{ + {URL: "/user", Methods: []string{"GET"}}, + {URL: "/companies", Methods: []string{"GET"}}, + }, + }, + }, + }, + { + ID: "pol2", + AccessRights: map[string]user.AccessDefinition{ + "a": { + AllowedURLs: []user.AccessSpec{ + {URL: "/user", Methods: []string{"POST", "PATCH", "PUT"}}, + {URL: "/companies", Methods: []string{"POST"}}, + {URL: "/admin", Methods: []string{"GET", "POST"}}, + }, + }, + }, + }, + { + ID: "pol3", + AccessRights: map[string]user.AccessDefinition{ + "a": { + AllowedURLs: []user.AccessSpec{ + {URL: "/admin/cache", Methods: []string{"DELETE"}}, + }, + }, + }, + }, + } + + session.SetCustomPolicies(policies) + + assert.NoError(t, svc.Apply(session)) + + want := []user.AccessSpec{ + {URL: "/user", Methods: []string{"GET", "POST", "PATCH", "PUT"}}, + {URL: "/companies", Methods: []string{"GET", "POST"}}, + {URL: "/admin", Methods: []string{"GET", "POST"}}, + {URL: "/admin/cache", Methods: []string{"DELETE"}}, + } + + assert.Equal(t, want, session.AccessRights["a"].AllowedURLs) +} diff --git a/user/custom_policies.go b/user/custom_policies.go index bdbb7f3d12a..3ac8c852b92 100644 --- a/user/custom_policies.go +++ b/user/custom_policies.go @@ -6,10 +6,26 @@ import ( "fmt" ) +// CustomPolicies returns a map of custom policies on the session. +// To preserve policy order, use GetCustomPolicies instead. func (s *SessionState) CustomPolicies() (map[string]Policy, error) { + customPolicies, err := s.GetCustomPolicies() + if err != nil { + return nil, err + } + + result := make(map[string]Policy, len(customPolicies)) + for i := 0; i < len(customPolicies); i++ { + result[customPolicies[i].ID] = customPolicies[i] + } + + return result, nil +} + +// GetCustomPolicies is like CustomPolicies but returns the list, preserving order. +func (s *SessionState) GetCustomPolicies() ([]Policy, error) { var ( customPolicies []Policy - ret map[string]Policy ) metadataPolicies, found := s.MetaData["policies"].([]interface{}) @@ -22,16 +38,14 @@ func (s *SessionState) CustomPolicies() (map[string]Policy, error) { return nil, fmt.Errorf("failed to marshal metadata policies: %w", err) } - _ = json.Unmarshal(polJSON, &customPolicies) - - ret = make(map[string]Policy, len(customPolicies)) - for i := 0; i < len(customPolicies); i++ { - ret[customPolicies[i].ID] = customPolicies[i] + if err := json.Unmarshal(polJSON, &customPolicies); err != nil { + return nil, fmt.Errorf("failed to unmarshal metadata policies: %w", err) } - return ret, nil + return customPolicies, err } +// SetCustomPolicies sets custom policies into session metadata. func (s *SessionState) SetCustomPolicies(list []Policy) { if s.MetaData == nil { s.MetaData = make(map[string]interface{}) diff --git a/user/session.go b/user/session.go index 929e8efe2d3..ec6ebfbf11b 100644 --- a/user/session.go +++ b/user/session.go @@ -10,11 +10,8 @@ import ( "github.com/TykTechnologies/graphql-go-tools/pkg/graphql" "github.com/TykTechnologies/tyk/apidef" - logger "github.com/TykTechnologies/tyk/log" ) -var log = logger.Get() - type HashType string const (