diff --git a/go.mod b/go.mod index c779790a6..59da50095 100644 --- a/go.mod +++ b/go.mod @@ -24,6 +24,7 @@ require ( github.com/petar-dambovaliev/aho-corasick v0.0.0-20211021192214-5ab2d9280aa9 github.com/tidwall/gjson v1.14.4 golang.org/x/net v0.11.0 + golang.org/x/sync v0.1.0 rsc.io/binaryregexp v0.2.0 ) diff --git a/go.sum b/go.sum index bf2237977..12f51696a 100644 --- a/go.sum +++ b/go.sum @@ -37,6 +37,7 @@ golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190922100055-0a153f010e69/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/internal/corazawaf/rule.go b/internal/corazawaf/rule.go index 7616ccaca..e67e6fb02 100644 --- a/internal/corazawaf/rule.go +++ b/internal/corazawaf/rule.go @@ -15,6 +15,7 @@ import ( "github.com/corazawaf/coraza/v3/experimental/plugins/macro" "github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes" "github.com/corazawaf/coraza/v3/internal/corazarules" + "github.com/corazawaf/coraza/v3/internal/memoize" "github.com/corazawaf/coraza/v3/types" "github.com/corazawaf/coraza/v3/types/variables" ) @@ -456,7 +457,12 @@ func (r *Rule) AddVariable(v variables.RuleVariable, key string, iscount bool) e var re *regexp.Regexp if len(key) > 2 && key[0] == '/' && key[len(key)-1] == '/' { key = key[1 : len(key)-1] - re = regexp.MustCompile(key) + + if vare, err := memoize.Do(key, func() 
(interface{}, error) { return regexp.Compile(key) }); err != nil { + panic(err) + } else { + re = vare.(*regexp.Regexp) + } } if multiphaseEvaluation { @@ -521,7 +527,11 @@ func (r *Rule) AddVariableNegation(v variables.RuleVariable, key string) error { var re *regexp.Regexp if len(key) > 2 && key[0] == '/' && key[len(key)-1] == '/' { key = key[1 : len(key)-1] - re = regexp.MustCompile(key) + if vare, err := memoize.Do(key, func() (interface{}, error) { return regexp.Compile(key) }); err != nil { + panic(err) + } else { + re = vare.(*regexp.Regexp) + } } // Prevent sigsev if r == nil { diff --git a/internal/memoize/cache.go b/internal/memoize/cache.go new file mode 100644 index 000000000..030509632 --- /dev/null +++ b/internal/memoize/cache.go @@ -0,0 +1,40 @@ +// Copyright 2023 Juan Pablo Tosso and the OWASP Coraza contributors +// SPDX-License-Identifier: Apache-2.0 + +//go:build !tinygo + +// Heavily inspired by https://github.com/patrickmn/go-cache/blob/master/cache.go + +package memoize + +import ( + "sync" +) + +type cache struct { + mu sync.RWMutex + entries map[string]interface{} +} + +func newCache() *cache { + return &cache{ + entries: make(map[string]interface{}), + } +} + +func (c *cache) set(key string, value interface{}) { + c.mu.Lock() + c.entries[key] = value + c.mu.Unlock() +} + +func (c *cache) get(key string) (interface{}, bool) { + c.mu.RLock() + item, found := c.entries[key] + if !found { + c.mu.RUnlock() + return nil, false + } + c.mu.RUnlock() + return item, true +} diff --git a/internal/memoize/cache_test.go b/internal/memoize/cache_test.go new file mode 100644 index 000000000..9eec1fa88 --- /dev/null +++ b/internal/memoize/cache_test.go @@ -0,0 +1,28 @@ +// Copyright 2023 Juan Pablo Tosso and the OWASP Coraza contributors +// SPDX-License-Identifier: Apache-2.0 + +//go:build !tinygo + +package memoize + +import "testing" + +func TestCache(t *testing.T) { + tc := newCache() + + _, found := tc.get("key1") + if want, have := false, found; want 
!= have { + t.Fatalf("unexpected value, want %t, have %t", want, have) + } + + tc.set("key1", 1) + + item, found := tc.get("key1") + if want, have := true, found; want != have { + t.Fatalf("unexpected value, want %t, have %t", want, have) + } + + if want, have := 1, item.(int); want != have { + t.Fatalf("unexpected value, want %d, have %d", want, have) + } +} diff --git a/internal/memoize/memoize.go b/internal/memoize/memoize.go new file mode 100644 index 000000000..da90083f6 --- /dev/null +++ b/internal/memoize/memoize.go @@ -0,0 +1,45 @@ +// Copyright 2023 Juan Pablo Tosso and the OWASP Coraza contributors +// SPDX-License-Identifier: Apache-2.0 + +//go:build !tinygo + +// https://github.com/kofalt/go-memoize/blob/master/memoize.go + +package memoize + +import ( + "golang.org/x/sync/singleflight" +) + +var doer = makeDoer(newCache(), &singleflight.Group{}) + +// Do executes and returns the results of the given function, unless there was a cached +// value of the same key. Only one execution is in-flight for a given key at a time. +// Results are cached only when fn returns a nil error. NOTE(review): all callers share one cache keyed by string alone - confirm keys cannot collide across value types. 
+func Do(key string, fn func() (interface{}, error)) (interface{}, error) { + value, err, _ := doer(key, fn) + return value, err +} + +// makeDoer returns a function that serves key from cache when present, otherwise runs fn through group and caches successful results; the returned bool reports a cache hit. +func makeDoer(cache *cache, group *singleflight.Group) func(string, func() (interface{}, error)) (interface{}, error, bool) { + return func(key string, fn func() (interface{}, error)) (interface{}, error, bool) { + // Check cache + value, found := cache.get(key) + if found { + return value, nil, true + } + + // Combine memoized function with a cache store + value, err, _ := group.Do(key, func() (interface{}, error) { + data, innerErr := fn() + if innerErr == nil { + cache.set(key, data) + } + + return data, innerErr + }) + + return value, err, false + } +} diff --git a/internal/memoize/memoize_test.go b/internal/memoize/memoize_test.go new file mode 100644 index 000000000..d5c0a480b --- /dev/null +++ b/internal/memoize/memoize_test.go @@ -0,0 +1,128 @@ +// Copyright 2023 Juan Pablo Tosso and the OWASP Coraza contributors +// SPDX-License-Identifier: Apache-2.0 + +//go:build !tinygo + +// https://github.com/kofalt/go-memoize/blob/master/memoize.go + +package memoize + +import ( + "errors" + "testing" + + "golang.org/x/sync/singleflight" +) + +func TestSuccessCall(t *testing.T) { + do := makeDoer(newCache(), &singleflight.Group{}) + + expensiveCalls := 0 + + // Function tracks how many times it's been called + expensive := func() (interface{}, error) { + expensiveCalls++ + return expensiveCalls, nil + } + + // First call SHOULD NOT be cached + result, err, cached := do("key1", expensive) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + + if want, have := 1, result.(int); want != have { + t.Fatalf("unexpected value, want %d, have %d", want, have) + } + + if want, have := false, cached; want != have { + t.Fatalf("unexpected caching, want %t, have %t", want, have) + } + + // Second call on same key SHOULD be cached + result, 
err, cached = do("key1", expensive) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + + if want, have := 1, result.(int); want != have { + t.Fatalf("unexpected value, want %d, have %d", want, have) + } + + if want, have := true, cached; want != have { + t.Fatalf("unexpected caching, want %t, have %t", want, have) + } + + // First call on a new key SHOULD NOT be cached + result, err, cached = do("key2", expensive) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + + if want, have := 2, result.(int); want != have { + t.Fatalf("unexpected value, want %d, have %d", want, have) + } + + if want, have := false, cached; want != have { + t.Fatalf("unexpected caching, want %t, have %t", want, have) + } +} + +func TestFailedCall(t *testing.T) { + do := makeDoer(newCache(), &singleflight.Group{}) + + calls := 0 + + // This function fails if and only if it has not been called before, i.e. only on its first call. + twoForTheMoney := func() (interface{}, error) { + calls++ + + if calls == 1 { + return calls, errors.New("Try again") + } else { + return calls, nil + } + } + + // First call should fail, and not be cached + result, err, cached := do("key1", twoForTheMoney) + if err == nil { + t.Fatalf("expected error") + } + + if want, have := 1, result.(int); want != have { + t.Fatalf("unexpected value, want %d, have %d", want, have) + } + + if want, have := false, cached; want != have { + t.Fatalf("unexpected caching, want %t, have %t", want, have) + } + + // Second call should succeed, and not be cached + result, err, cached = do("key1", twoForTheMoney) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + + if want, have := 2, result.(int); want != have { + t.Fatalf("unexpected value, want %d, have %d", want, have) + } + + if want, have := false, cached; want != have { + t.Fatalf("unexpected caching, want %t, have %t", want, have) + } + + // Third call should succeed, and be cached + result, err, cached = do("key1", twoForTheMoney) + if err != nil { + 
t.Fatalf("unexpected error: %s", err.Error()) + } + + if want, have := 2, result.(int); want != have { + t.Fatalf("unexpected value, want %d, have %d", want, have) + } + + if want, have := true, cached; want != have { + t.Fatalf("unexpected caching, want %t, have %t", want, have) + } +} diff --git a/internal/memoize/noop.go b/internal/memoize/noop.go new file mode 100644 index 000000000..f2e7a6063 --- /dev/null +++ b/internal/memoize/noop.go @@ -0,0 +1,10 @@ +// Copyright 2023 Juan Pablo Tosso and the OWASP Coraza contributors +// SPDX-License-Identifier: Apache-2.0 + +//go:build tinygo + +package memoize + +func Do(_ string, fn func() (interface{}, error)) (interface{}, error) { + return fn() +} diff --git a/internal/operators/restpath.go b/internal/operators/restpath.go index 526cb35bf..f1e4a8911 100644 --- a/internal/operators/restpath.go +++ b/internal/operators/restpath.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes" + "github.com/corazawaf/coraza/v3/internal/memoize" ) var rePathTokenRe = regexp.MustCompile(`\{([^\}]+)\}`) @@ -30,11 +31,12 @@ func newRESTPath(options plugintypes.OperatorOptions) (plugintypes.Operator, err for _, token := range rePathTokenRe.FindAllStringSubmatch(data, -1) { data = strings.Replace(data, token[0], fmt.Sprintf("(?P<%s>.*)", token[1]), 1) } - re, err := regexp.Compile(data) + + re, err := memoize.Do(data, func() (interface{}, error) { return regexp.Compile(data) }) if err != nil { return nil, err } - return &restpath{re: re}, nil + return &restpath{re: re.(*regexp.Regexp)}, nil } func (o *restpath) Evaluate(tx plugintypes.TransactionState, value string) bool { diff --git a/internal/operators/rx.go b/internal/operators/rx.go index 147890e1c..e801c9f72 100644 --- a/internal/operators/rx.go +++ b/internal/operators/rx.go @@ -14,6 +14,7 @@ import ( "rsc.io/binaryregexp" "github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes" + 
"github.com/corazawaf/coraza/v3/internal/memoize" ) type rx struct { @@ -35,15 +36,14 @@ func newRX(options plugintypes.OperatorOptions) (plugintypes.Operator, error) { return newBinaryRX(options) } - re, err := regexp.Compile(data) + re, err := memoize.Do(data, func() (interface{}, error) { return regexp.Compile(data) }) if err != nil { return nil, err } - return &rx{re: re}, nil + return &rx{re: re.(*regexp.Regexp)}, nil } func (o *rx) Evaluate(tx plugintypes.TransactionState, value string) bool { - if tx.Capturing() { match := o.re.FindStringSubmatch(value) if len(match) == 0 { @@ -72,15 +72,14 @@ var _ plugintypes.Operator = (*binaryRX)(nil) func newBinaryRX(options plugintypes.OperatorOptions) (plugintypes.Operator, error) { data := options.Arguments - re, err := binaryregexp.Compile(data) + re, err := memoize.Do(data, func() (interface{}, error) { return binaryregexp.Compile(data) }) if err != nil { return nil, err } - return &binaryRX{re: re}, nil + return &binaryRX{re: re.(*binaryregexp.Regexp)}, nil } func (o *binaryRX) Evaluate(tx plugintypes.TransactionState, value string) bool { - if tx.Capturing() { match := o.re.FindStringSubmatch(value) if len(match) == 0 { diff --git a/internal/operators/validate_nid.go b/internal/operators/validate_nid.go index ce6192a2b..383f160b6 100644 --- a/internal/operators/validate_nid.go +++ b/internal/operators/validate_nid.go @@ -12,6 +12,7 @@ import ( "strings" "github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes" + "github.com/corazawaf/coraza/v3/internal/memoize" ) type validateNidFunction = func(input string) bool @@ -39,11 +40,13 @@ func newValidateNID(options plugintypes.OperatorOptions) (plugintypes.Operator, default: return nil, fmt.Errorf("invalid @validateNid argument") } - re, err := regexp.Compile(expr) + + re, err := memoize.Do(expr, func() (interface{}, error) { return regexp.Compile(expr) }) if err != nil { return nil, err } - return &validateNid{fn: fn, re: re}, nil + + return &validateNid{fn: 
fn, re: re.(*regexp.Regexp)}, nil } func (o *validateNid) Evaluate(tx plugintypes.TransactionState, value string) bool { diff --git a/internal/seclang/directives.go b/internal/seclang/directives.go index 8b0628f0a..7a158c349 100644 --- a/internal/seclang/directives.go +++ b/internal/seclang/directives.go @@ -16,6 +16,7 @@ import ( "github.com/corazawaf/coraza/v3/debuglog" "github.com/corazawaf/coraza/v3/internal/auditlog" "github.com/corazawaf/coraza/v3/internal/corazawaf" + "github.com/corazawaf/coraza/v3/internal/memoize" utils "github.com/corazawaf/coraza/v3/internal/strings" "github.com/corazawaf/coraza/v3/types" ) @@ -731,9 +732,13 @@ func directiveSecAuditLogRelevantStatus(options *DirectiveOptions) error { return errEmptyOptions } - var err error - options.WAF.AuditLogRelevantStatus, err = regexp.Compile(options.Opts) - return err + re, err := memoize.Do(options.Opts, func() (interface{}, error) { return regexp.Compile(options.Opts) }) + if err != nil { + return err + } + + options.WAF.AuditLogRelevantStatus = re.(*regexp.Regexp) + return nil } // Description: Defines which parts of each transaction are going to be recorded