diff --git a/contains_in_test.go b/contains_in_test.go new file mode 100644 index 0000000..962a3de --- /dev/null +++ b/contains_in_test.go @@ -0,0 +1,55 @@ +package bexpr + +import ( + "testing" + + "github.com/hashicorp/go-bexpr/grammar" +) + +func TestContainsVsIn(t *testing.T) { + claims := map[string]any{ + "userinfo": map[string]any{ + "groups": "totallynotanadmin", + "email": "admin@company.com", + }, + } + + tests := []struct { + name string + filter string + expect bool + }{ + { + name: "in does not find admin in totallynotanadmin", + filter: `"admin" in "/userinfo/groups"`, + expect: false, + }, + { + name: "contains finds admin in totallynotanadmin", + filter: `"/userinfo/groups" contains "admin"`, + expect: true, + }, + { + name: "contains on email does substring match", + filter: `"/userinfo/email" contains "@company.com"`, + expect: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ast, _ := grammar.Parse("", []byte(tt.filter)) + expr := ast.(*grammar.MatchExpression) + + eval, _ := CreateEvaluator(tt.filter) + result, _ := eval.Evaluate(claims) + + t.Logf("\n Filter: %s\n Operator: %s\n Result: %v [expect: %v]", + tt.filter, expr.Operator, result, tt.expect) + + if result != tt.expect { + t.Errorf("expected %v, got %v", tt.expect, result) + } + }) + } +} diff --git a/evaluate.go b/evaluate.go index 9c3dfd5..c76357e 100644 --- a/evaluate.go +++ b/evaluate.go @@ -178,13 +178,26 @@ func doMatchIn(expression *grammar.MatchExpression, value reflect.Value) (bool, } case reflect.String: - return strings.Contains(value.String(), matchValue.(string)), nil + return false, nil default: - return false, fmt.Errorf("cannot perform in/contains operations on type %s for selector: %q", kind, expression.Selector) + return false, fmt.Errorf("cannot perform in operations on type %s for selector: %q", kind, expression.Selector) } } +func doMatchContains(expression *grammar.MatchExpression, value reflect.Value) (bool, error) { + if value.Kind() != reflect.String { + return false, fmt.Errorf("contains operator only works on strings, not %s for selector: %q", value.Kind(), expression.Selector) + } + + matchValue, err := getMatchExprValue(expression, reflect.String) + if err != nil { + return false, fmt.Errorf("error getting match value in expression: %w", err) + } + + return strings.Contains(value.String(), matchValue.(string)), nil +} + func doMatchIsEmpty(matcher *grammar.MatchExpression, value reflect.Value) (bool, error) { // NOTE: see preconditions in evaluategrammar.MatchExpressionRecurse return value.Len() == 0, nil @@ -369,6 +382,14 @@ func evaluateMatchExpression(expression *grammar.MatchExpression, datum interfac return !result, nil } return false, err + case grammar.MatchContains: + return doMatchContains(expression, rvalue) + case grammar.MatchNotContains: + result, err := doMatchContains(expression, rvalue) + if err == nil { + return !result, nil + } + return false, err default: return false, fmt.Errorf("invalid match operation: %d", expression.Operator) } diff --git a/grammar/ast.go b/grammar/ast.go index 14a1c5f..b87af37 100644 --- a/grammar/ast.go +++ b/grammar/ast.go @@ -59,6 +59,8 @@ const ( MatchIsNotEmpty MatchMatches MatchNotMatches + MatchContains + MatchNotContains ) func (op MatchOperator) String() string { @@ -79,6 +81,10 @@ func (op MatchOperator) String() string { return "Matches" case MatchNotMatches: return "Not Matches" + case MatchContains: + return "Contains" + case MatchNotContains: + return "Not Contains" default: return "UNKNOWN" } @@ -113,6 +119,10 @@ func (op MatchOperator) NotPresentDisposition() bool { case MatchNotMatches: // M["x"] not matches is true. Nothing matches a missing key return true + case MatchContains: + return false + case MatchNotContains: + return true default: // Should never be reached as every operator should explicitly define its // behavior. diff --git a/grammar/grammar.go b/grammar/grammar.go index ad675ed..8164a02 100644 --- a/grammar/grammar.go +++ b/grammar/grammar.go @@ -2314,7 +2314,7 @@ func (p *parser) callonMatchNotIn1() (any, error) { } func (c *current) onMatchContains1() (any, error) { - return MatchIn, nil + return MatchContains, nil } func (p *parser) callonMatchContains1() (any, error) { @@ -2324,7 +2324,7 @@ func (p *parser) callonMatchContains1() (any, error) { } func (c *current) onMatchNotContains1() (any, error) { - return MatchNotIn, nil + return MatchNotContains, nil } func (p *parser) callonMatchNotContains1() (any, error) { diff --git a/grammar/grammar.peg b/grammar/grammar.peg index d2ff1e2..ef2c933 100644 --- a/grammar/grammar.peg +++ b/grammar/grammar.peg @@ -135,10 +135,10 @@ MatchNotIn <- _ "not" _ "in" _ { return MatchNotIn, nil } MatchContains <- _ "contains" _ { - return MatchIn, nil + return MatchContains, nil } MatchNotContains <- _ "not" _ "contains" _ { - return MatchNotIn, nil + return MatchNotContains, nil } MatchMatches <- _ "matches" _ { return MatchMatches, nil