Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: adds memoize implementation for regexes and ahocorasick #836

Merged
merged 10 commits into from
Aug 6, 2023
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ require (
github.com/petar-dambovaliev/aho-corasick v0.0.0-20211021192214-5ab2d9280aa9
github.com/tidwall/gjson v1.14.4
golang.org/x/net v0.11.0
golang.org/x/sync v0.1.0
rsc.io/binaryregexp v0.2.0
)

Expand Down
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190922100055-0a153f010e69/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
Expand Down
14 changes: 12 additions & 2 deletions internal/corazawaf/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"github.com/corazawaf/coraza/v3/experimental/plugins/macro"
"github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes"
"github.com/corazawaf/coraza/v3/internal/corazarules"
"github.com/corazawaf/coraza/v3/internal/memoize"
"github.com/corazawaf/coraza/v3/types"
"github.com/corazawaf/coraza/v3/types/variables"
)
Expand Down Expand Up @@ -456,7 +457,12 @@
var re *regexp.Regexp
if len(key) > 2 && key[0] == '/' && key[len(key)-1] == '/' {
key = key[1 : len(key)-1]
re = regexp.MustCompile(key)

if vare, err := memoize.Do(key, func() (interface{}, error) { return regexp.Compile(key) }); err != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be worth extracting a function for the two usages

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean for the regex and binaryregex?

panic(err)

Check warning on line 462 in internal/corazawaf/rule.go

View check run for this annotation

Codecov / codecov/patch

internal/corazawaf/rule.go#L462

Added line #L462 was not covered by tests
} else {
re = vare.(*regexp.Regexp)
}
}

if multiphaseEvaluation {
Expand Down Expand Up @@ -521,7 +527,11 @@
var re *regexp.Regexp
if len(key) > 2 && key[0] == '/' && key[len(key)-1] == '/' {
key = key[1 : len(key)-1]
re = regexp.MustCompile(key)
if vare, err := memoize.Do(key, func() (interface{}, error) { return regexp.Compile(key) }); err != nil {
panic(err)

Check warning on line 531 in internal/corazawaf/rule.go

View check run for this annotation

Codecov / codecov/patch

internal/corazawaf/rule.go#L531

Added line #L531 was not covered by tests
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should never panic, you can return error here

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree, however this should be fixed in main, I am just reproducing the MustCompile behaviour.

} else {
re = vare.(*regexp.Regexp)
}
}
// Prevent sigsev
if r == nil {
Expand Down
11 changes: 11 additions & 0 deletions internal/memoize/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Memoize

Memoize allows to cache certain expensive function calls and
cache the result. The main advantage in Coraza is to memoize
the regexes when the connects spins up more than one WAF in
the same process and hence same regexes are being compiled
over and over.

Currently it is opt-in under the `memoize_regex` build tag
Copy link
Member

@M4tteoP M4tteoP Jul 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: what about also adding a one-line description in the main readme, under https://github.com/corazawaf/coraza#build-tags?

as under a misuse it could lead to a memory leak as currently
the cache is global.
47 changes: 47 additions & 0 deletions internal/memoize/memoize.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright 2023 Juan Pablo Tosso and the OWASP Coraza contributors
// SPDX-License-Identifier: Apache-2.0

//go:build !tinygo && memoize_regex

// https://github.com/kofalt/go-memoize/blob/master/memoize.go

package memoize

import (
"sync"

"golang.org/x/sync/singleflight"
)

var doer = makeDoer(new(sync.Map), new(singleflight.Group))

// Do executes and returns the results of the given function, unless there was a cached
// value of the same key. Only one execution is in-flight for a given key at a time.
// The boolean return value indicates whether v was previously stored.
func Do(key string, fn func() (interface{}, error)) (interface{}, error) {
value, err, _ := doer(key, fn)
return value, err
}

// makeDoer returns a function that executes and returns the results of the given function
func makeDoer(cache *sync.Map, group *singleflight.Group) func(string, func() (interface{}, error)) (interface{}, error, bool) {
return func(key string, fn func() (interface{}, error)) (interface{}, error, bool) {
// Check cache
value, found := cache.Load(key)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit, can combine two lines

if found {
return value, nil, true
}

// Combine memoized function with a cache store
value, err, _ := group.Do(key, func() (interface{}, error) {
data, innerErr := fn()
if innerErr == nil {
cache.Store(key, data)
}

return data, innerErr
})

return value, err, false
}
}
169 changes: 169 additions & 0 deletions internal/memoize/memoize_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
// Copyright 2023 Juan Pablo Tosso and the OWASP Coraza contributors
// SPDX-License-Identifier: Apache-2.0

//go:build !tinygo && memoize_regex

// https://github.com/kofalt/go-memoize/blob/master/memoize.go

package memoize

import (
"errors"
"sync"
"testing"

"golang.org/x/sync/singleflight"
)

func TestDo(t *testing.T) {
expensiveCalls := 0

// Function tracks how many times its been called
expensive := func() (interface{}, error) {
expensiveCalls++
return expensiveCalls, nil
}

// First call SHOULD NOT be cached
result, err := Do("key1", expensive)
if err != nil {
t.Fatalf("unexpected error: %s", err.Error())
}

if want, have := 1, result.(int); want != have {
t.Fatalf("unexpected value, want %d, have %d", want, have)
}

// Second call on same key SHOULD be cached
result, err = Do("key1", expensive)
if err != nil {
t.Fatalf("unexpected error: %s", err.Error())
}

if want, have := 1, result.(int); want != have {
t.Fatalf("unexpected value, want %d, have %d", want, have)
}

// First call on a new key SHOULD NOT be cached
result, err = Do("key2", expensive)
if err != nil {
t.Fatalf("unexpected error: %s", err.Error())
}

if want, have := 2, result.(int); want != have {
t.Fatalf("unexpected value, want %d, have %d", want, have)
}
}

func TestSuccessCall(t *testing.T) {
do := makeDoer(new(sync.Map), &singleflight.Group{})

expensiveCalls := 0

// Function tracks how many times its been called
expensive := func() (interface{}, error) {
expensiveCalls++
return expensiveCalls, nil
}

// First call SHOULD NOT be cached
result, err, cached := do("key1", expensive)
if err != nil {
t.Fatalf("unexpected error: %s", err.Error())
}

if want, have := 1, result.(int); want != have {
t.Fatalf("unexpected value, want %d, have %d", want, have)
}

if want, have := false, cached; want != have {
t.Fatalf("unexpected caching, want %t, have %t", want, have)
}

// Second call on same key SHOULD be cached
result, err, cached = do("key1", expensive)
if err != nil {
t.Fatalf("unexpected error: %s", err.Error())
}

if want, have := 1, result.(int); want != have {
t.Fatalf("unexpected value, want %d, have %d", want, have)
}

if want, have := true, cached; want != have {
t.Fatalf("unexpected caching, want %t, have %t", want, have)
}

// First call on a new key SHOULD NOT be cached
result, err, cached = do("key2", expensive)
if err != nil {
t.Fatalf("unexpected error: %s", err.Error())
}

if want, have := 2, result.(int); want != have {
t.Fatalf("unexpected value, want %d, have %d", want, have)
}

if want, have := false, cached; want != have {
t.Fatalf("unexpected caching, want %t, have %t", want, have)
}
}

func TestFailedCall(t *testing.T) {
do := makeDoer(new(sync.Map), &singleflight.Group{})

calls := 0

// This function will fail IFF it has not been called before.
twoForTheMoney := func() (interface{}, error) {
calls++

if calls == 1 {
return calls, errors.New("Try again")
} else {
return calls, nil
}
}

// First call should fail, and not be cached
result, err, cached := do("key1", twoForTheMoney)
if err == nil {
t.Fatalf("expected error")
}

if want, have := 1, result.(int); want != have {
t.Fatalf("unexpected value, want %d, have %d", want, have)
}

if want, have := false, cached; want != have {
t.Fatalf("unexpected caching, want %t, have %t", want, have)
}

// Second call should succeed, and not be cached
result, err, cached = do("key1", twoForTheMoney)
if err != nil {
t.Fatalf("unexpected error: %s", err.Error())
}

if want, have := 2, result.(int); want != have {
t.Fatalf("unexpected value, want %d, have %d", want, have)
}

if want, have := false, cached; want != have {
t.Fatalf("unexpected caching, want %t, have %t", want, have)
}

// Third call should succeed, and be cached
result, err, cached = do("key1", twoForTheMoney)
if err != nil {
t.Fatalf("unexpected error: %s", err.Error())
}

if want, have := 2, result.(int); want != have {
t.Fatalf("unexpected value, want %d, have %d", want, have)
}

if want, have := true, cached; want != have {
t.Fatalf("unexpected caching, want %t, have %t", want, have)
}
}
10 changes: 10 additions & 0 deletions internal/memoize/noop.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright 2023 Juan Pablo Tosso and the OWASP Coraza contributors
// SPDX-License-Identifier: Apache-2.0

//go:build tinygo || !memoize_regex

package memoize

func Do(_ string, fn func() (interface{}, error)) (interface{}, error) {
return fn()
}
6 changes: 4 additions & 2 deletions internal/operators/restpath.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"strings"

"github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes"
"github.com/corazawaf/coraza/v3/internal/memoize"
)

var rePathTokenRe = regexp.MustCompile(`\{([^\}]+)\}`)
Expand All @@ -30,11 +31,12 @@ func newRESTPath(options plugintypes.OperatorOptions) (plugintypes.Operator, err
for _, token := range rePathTokenRe.FindAllStringSubmatch(data, -1) {
data = strings.Replace(data, token[0], fmt.Sprintf("(?P<%s>.*)", token[1]), 1)
}
re, err := regexp.Compile(data)

re, err := memoize.Do(data, func() (interface{}, error) { return regexp.Compile(data) })
if err != nil {
return nil, err
}
return &restpath{re: re}, nil
return &restpath{re: re.(*regexp.Regexp)}, nil
}

func (o *restpath) Evaluate(tx plugintypes.TransactionState, value string) bool {
Expand Down
11 changes: 5 additions & 6 deletions internal/operators/rx.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"rsc.io/binaryregexp"

"github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes"
"github.com/corazawaf/coraza/v3/internal/memoize"
)

type rx struct {
Expand All @@ -35,15 +36,14 @@ func newRX(options plugintypes.OperatorOptions) (plugintypes.Operator, error) {
return newBinaryRX(options)
}

re, err := regexp.Compile(data)
re, err := memoize.Do(data, func() (interface{}, error) { return regexp.Compile(data) })
if err != nil {
return nil, err
}
return &rx{re: re}, nil
return &rx{re: re.(*regexp.Regexp)}, nil
}

func (o *rx) Evaluate(tx plugintypes.TransactionState, value string) bool {

if tx.Capturing() {
match := o.re.FindStringSubmatch(value)
if len(match) == 0 {
Expand Down Expand Up @@ -72,15 +72,14 @@ var _ plugintypes.Operator = (*binaryRX)(nil)
func newBinaryRX(options plugintypes.OperatorOptions) (plugintypes.Operator, error) {
data := options.Arguments

re, err := binaryregexp.Compile(data)
re, err := memoize.Do(data, func() (interface{}, error) { return binaryregexp.Compile(data) })
if err != nil {
return nil, err
}
return &binaryRX{re: re}, nil
return &binaryRX{re: re.(*binaryregexp.Regexp)}, nil
}

func (o *binaryRX) Evaluate(tx plugintypes.TransactionState, value string) bool {

if tx.Capturing() {
match := o.re.FindStringSubmatch(value)
if len(match) == 0 {
Expand Down
7 changes: 5 additions & 2 deletions internal/operators/validate_nid.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"strings"

"github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes"
"github.com/corazawaf/coraza/v3/internal/memoize"
)

type validateNidFunction = func(input string) bool
Expand Down Expand Up @@ -39,11 +40,13 @@
default:
return nil, fmt.Errorf("invalid @validateNid argument")
}
re, err := regexp.Compile(expr)

re, err := memoize.Do(expr, func() (interface{}, error) { return regexp.Compile(expr) })

Check warning on line 44 in internal/operators/validate_nid.go

View check run for this annotation

Codecov / codecov/patch

internal/operators/validate_nid.go#L44

Added line #L44 was not covered by tests
if err != nil {
return nil, err
}
return &validateNid{fn: fn, re: re}, nil

return &validateNid{fn: fn, re: re.(*regexp.Regexp)}, nil

Check warning on line 49 in internal/operators/validate_nid.go

View check run for this annotation

Codecov / codecov/patch

internal/operators/validate_nid.go#L49

Added line #L49 was not covered by tests
}

func (o *validateNid) Evaluate(tx plugintypes.TransactionState, value string) bool {
Expand Down
Loading
Loading