Skip to content

Commit

Permalink
templating
Browse files Browse the repository at this point in the history
Signed-off-by: Charles-Edouard Brétéché <charles.edouard@nirmata.com>
  • Loading branch information
eddycharly committed Sep 20, 2024
1 parent 6f1f9c6 commit 20d691f
Show file tree
Hide file tree
Showing 21 changed files with 315 additions and 241 deletions.
18 changes: 4 additions & 14 deletions pkg/apis/policy/v1alpha1/assertion_tree.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package v1alpha1

import (
"context"
"sync"

"github.com/kyverno/kyverno-json/pkg/core/assertion"
"github.com/kyverno/kyverno-json/pkg/core/templating"
"k8s.io/apimachinery/pkg/util/json"
)

Expand All @@ -13,24 +11,20 @@ import (
// +kubebuilder:validation:Type:=""
// AssertionTree represents an assertion tree.
type AssertionTree struct {
_tree any
_assertion func() (assertion.Assertion, error)
_tree any
}

func NewAssertionTree(value any) AssertionTree {
return AssertionTree{
_tree: value,
_assertion: sync.OnceValues(func() (assertion.Assertion, error) {
return assertion.Parse(context.Background(), value)
}),
}
}

func (t *AssertionTree) Assertion() (assertion.Assertion, error) {
func (t *AssertionTree) Assertion(compiler templating.Compiler) (assertion.Assertion, error) {
if t._tree == nil {
return nil, nil
}
return t._assertion()
return assertion.Parse(t._tree, compiler)
}

func (a *AssertionTree) MarshalJSON() ([]byte, error) {
Expand All @@ -44,13 +38,9 @@ func (a *AssertionTree) UnmarshalJSON(data []byte) error {
return err
}
a._tree = v
a._assertion = sync.OnceValues(func() (assertion.Assertion, error) {
return assertion.Parse(context.Background(), v)
})
return nil
}

func (in *AssertionTree) DeepCopyInto(out *AssertionTree) {
out._tree = deepCopy(in._tree)
out._assertion = in._assertion
}
6 changes: 3 additions & 3 deletions pkg/commands/jp/query/command.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package query

import (
"context"
"encoding/json"
"errors"
"fmt"
Expand All @@ -11,7 +10,7 @@ import (

"github.com/jmespath-community/go-jmespath/pkg/parsing"
"github.com/kyverno/kyverno-json/pkg/command"
"github.com/kyverno/kyverno-json/pkg/engine/template"
"github.com/kyverno/kyverno-json/pkg/core/templating"
"github.com/spf13/cobra"
"sigs.k8s.io/yaml"
)
Expand Down Expand Up @@ -156,7 +155,8 @@ func loadInput(cmd *cobra.Command, file string) (any, error) {
}

func evaluate(input any, query string) (any, error) {
result, err := template.ExecuteJP(context.Background(), query, input, nil)
compiler := templating.NewCompiler(templating.CompilerOptions{})
result, err := templating.ExecuteJP(query, input, nil, compiler)
if err != nil {
if syntaxError, ok := err.(parsing.SyntaxError); ok {
return nil, fmt.Errorf("%s\n%s", syntaxError, syntaxError.HighlightLocation())
Expand Down
5 changes: 3 additions & 2 deletions pkg/commands/scan/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"strings"

"github.com/kyverno/kyverno-json/pkg/apis/policy/v1alpha1"
"github.com/kyverno/kyverno-json/pkg/engine/template"
"github.com/kyverno/kyverno-json/pkg/core/templating"
jsonengine "github.com/kyverno/kyverno-json/pkg/json-engine"
"github.com/kyverno/kyverno-json/pkg/payload"
"github.com/kyverno/kyverno-json/pkg/policy"
Expand Down Expand Up @@ -76,8 +76,9 @@ func (c *options) run(cmd *cobra.Command, _ []string) error {
return errors.New("payload is `null`")
}
out.println("Pre processing ...")
compiler := templating.NewCompiler(templating.CompilerOptions{})
for _, preprocessor := range c.preprocessors {
result, err := template.ExecuteJP(context.Background(), preprocessor, payload, nil)
result, err := templating.ExecuteJP(preprocessor, payload, nil, compiler)
if err != nil {
return err
}
Expand Down
73 changes: 37 additions & 36 deletions pkg/core/assertion/assertion.go
Original file line number Diff line number Diff line change
@@ -1,58 +1,56 @@
package assertion

import (
"context"
"errors"
"fmt"
"reflect"
"sync"

"github.com/jmespath-community/go-jmespath/pkg/binding"
"github.com/jmespath-community/go-jmespath/pkg/parsing"
"github.com/kyverno/kyverno-json/pkg/core/expression"
"github.com/kyverno/kyverno-json/pkg/core/projection"
"github.com/kyverno/kyverno-json/pkg/core/templating"
"github.com/kyverno/kyverno-json/pkg/engine/match"
"github.com/kyverno/kyverno-json/pkg/engine/template"
reflectutils "github.com/kyverno/kyverno-json/pkg/utils/reflect"
"k8s.io/apimachinery/pkg/util/validation/field"
)

type Assertion interface {
Assert(context.Context, *field.Path, any, binding.Bindings, ...template.Option) (field.ErrorList, error)
Assert(*field.Path, any, binding.Bindings) (field.ErrorList, error)
}

func Parse(ctx context.Context, assertion any) (node, error) {
func Parse(assertion any, compiler templating.Compiler) (node, error) {
switch reflectutils.GetKind(assertion) {
case reflect.Slice:
return parseSlice(ctx, assertion)
return parseSlice(assertion, compiler)
case reflect.Map:
return parseMap(ctx, assertion)
return parseMap(assertion, compiler)
default:
return parseScalar(ctx, assertion)
return parseScalar(assertion, compiler)
}
}

// node implements the Assertion interface using a delegate func
type node func(ctx context.Context, path *field.Path, value any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error)
type node func(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error)

func (n node) Assert(ctx context.Context, path *field.Path, value any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error) {
return n(ctx, path, value, bindings, opts...)
func (n node) Assert(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error) {
return n(path, value, bindings)
}

// parseSlice is the assertion represented by a slice.
// it first compares the length of the analysed resource with the length of the descendants.
// if lengths match all descendants are evaluated with their corresponding items.
func parseSlice(ctx context.Context, assertion any) (node, error) {
func parseSlice(assertion any, compiler templating.Compiler) (node, error) {
var assertions []node
valueOf := reflect.ValueOf(assertion)
for i := 0; i < valueOf.Len(); i++ {
sub, err := Parse(ctx, valueOf.Index(i).Interface())
sub, err := Parse(valueOf.Index(i).Interface(), compiler)
if err != nil {
return nil, err
}
assertions = append(assertions, sub)
}
return func(ctx context.Context, path *field.Path, value any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error) {
return func(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error) {
var errs field.ErrorList
if value == nil {
errs = append(errs, field.Invalid(path, value, "value is null"))
Expand All @@ -64,7 +62,7 @@ func parseSlice(ctx context.Context, assertion any) (node, error) {
errs = append(errs, field.Invalid(path, value, "lengths of slices don't match"))
} else {
for i := range assertions {
if _errs, err := assertions[i].Assert(ctx, path.Index(i), valueOf.Index(i).Interface(), bindings, opts...); err != nil {
if _errs, err := assertions[i].Assert(path.Index(i), valueOf.Index(i).Interface(), bindings); err != nil {
return nil, err
} else {
errs = append(errs, _errs...)
Expand All @@ -78,7 +76,7 @@ func parseSlice(ctx context.Context, assertion any) (node, error) {

// parseMap is the assertion represented by a map.
// it is responsible for projecting the analysed resource and passing the result to the descendant
func parseMap(ctx context.Context, assertion any) (node, error) {
func parseMap(assertion any, compiler templating.Compiler) (node, error) {
assertions := map[any]struct {
projection.Projection
node
Expand All @@ -87,16 +85,16 @@ func parseMap(ctx context.Context, assertion any) (node, error) {
for iter.Next() {
key := iter.Key().Interface()
value := iter.Value().Interface()
assertion, err := Parse(ctx, value)
assertion, err := Parse(value, compiler)
if err != nil {
return nil, err
}
entry := assertions[key]
entry.node = assertion
entry.Projection = projection.Parse(key)
entry.Projection = projection.Parse(key, compiler)
assertions[key] = entry
}
return func(ctx context.Context, path *field.Path, value any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error) {
return func(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error) {
var errs field.ErrorList
// if we assert against an empty object, value is expected to be not nil
if len(assertions) == 0 {
Expand All @@ -106,7 +104,7 @@ func parseMap(ctx context.Context, assertion any) (node, error) {
return errs, nil
}
for k, v := range assertions {
projected, found, err := v.Projection.Handler(ctx, value, bindings, opts...)
projected, found, err := v.Projection.Handler(value, bindings)
if err != nil {
return nil, field.InternalError(path.Child(fmt.Sprint(k)), err)
} else if !found {
Expand All @@ -124,7 +122,7 @@ func parseMap(ctx context.Context, assertion any) (node, error) {
if v.Projection.ForeachName != "" {
bindings = bindings.Register("$"+v.Projection.ForeachName, binding.NewBinding(i))
}
if _errs, err := v.Assert(ctx, path.Child(fmt.Sprint(k)).Index(i), valueOf.Index(i).Interface(), bindings, opts...); err != nil {
if _errs, err := v.Assert(path.Child(fmt.Sprint(k)).Index(i), valueOf.Index(i).Interface(), bindings); err != nil {
return nil, err
} else {
errs = append(errs, _errs...)
Expand All @@ -138,7 +136,7 @@ func parseMap(ctx context.Context, assertion any) (node, error) {
if v.Projection.ForeachName != "" {
bindings = bindings.Register("$"+v.Projection.ForeachName, binding.NewBinding(key))
}
if _errs, err := v.Assert(ctx, path.Child(fmt.Sprint(k)).Key(fmt.Sprint(key)), iter.Value().Interface(), bindings, opts...); err != nil {
if _errs, err := v.Assert(path.Child(fmt.Sprint(k)).Key(fmt.Sprint(key)), iter.Value().Interface(), bindings); err != nil {
return nil, err
} else {
errs = append(errs, _errs...)
Expand All @@ -148,7 +146,7 @@ func parseMap(ctx context.Context, assertion any) (node, error) {
return nil, field.TypeInvalid(path.Child(fmt.Sprint(k)), projected, "expected a slice or a map")
}
} else {
if _errs, err := v.Assert(ctx, path.Child(fmt.Sprint(k)), projected, bindings, opts...); err != nil {
if _errs, err := v.Assert(path.Child(fmt.Sprint(k)), projected, bindings); err != nil {
return nil, err
} else {
errs = append(errs, _errs...)
Expand All @@ -163,8 +161,8 @@ func parseMap(ctx context.Context, assertion any) (node, error) {
// parseScalar is the assertion represented by a leaf.
// it receives a value and compares it with an expected value.
// the expected value can be the result of an expression.
func parseScalar(_ context.Context, assertion any) (node, error) {
var project func(ctx context.Context, value any, bindings binding.Bindings, opts ...template.Option) (any, error)
func parseScalar(assertion any, compiler templating.Compiler) (node, error) {
var project func(value any, bindings binding.Bindings) (any, error)
switch typed := assertion.(type) {
case string:
expr := expression.Parse(typed)
Expand All @@ -176,36 +174,39 @@ func parseScalar(_ context.Context, assertion any) (node, error) {
}
switch expr.Engine {
case expression.EngineJP:
parse := sync.OnceValues(func() (parsing.ASTNode, error) {
parser := parsing.NewParser()
return parser.Parse(expr.Statement)
parse := sync.OnceValues(func() (templating.Program, error) {
return compiler.CompileJP(expr.Statement)
})
project = func(ctx context.Context, value any, bindings binding.Bindings, opts ...template.Option) (any, error) {
ast, err := parse()
project = func(value any, bindings binding.Bindings) (any, error) {
program, err := parse()
if err != nil {
return nil, err
}
return template.ExecuteAST(ctx, ast, value, bindings, opts...)
return program(value, bindings)
}
case expression.EngineCEL:
project = func(ctx context.Context, value any, bindings binding.Bindings, opts ...template.Option) (any, error) {
return template.ExecuteCEL(ctx, expr.Statement, value, bindings)
project = func(value any, bindings binding.Bindings) (any, error) {
program, err := compiler.CompileCEL(expr.Statement)
if err != nil {
return nil, err
}
return program(value, bindings)
}
default:
assertion = expr.Statement
}
}
return func(ctx context.Context, path *field.Path, value any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error) {
return func(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error) {
expected := assertion
if project != nil {
projected, err := project(ctx, value, bindings, opts...)
projected, err := project(value, bindings)
if err != nil {
return nil, field.InternalError(path, err)
}
expected = projected
}
var errs field.ErrorList
if match, err := match.Match(ctx, expected, value); err != nil {
if match, err := match.Match(expected, value); err != nil {
return nil, field.InternalError(path, err)
} else if !match {
errs = append(errs, field.Invalid(path, value, expectValueMessage(expected)))
Expand Down
7 changes: 4 additions & 3 deletions pkg/core/assertion/assertion_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package assertion

import (
"context"
"testing"

"github.com/jmespath-community/go-jmespath/pkg/binding"
"github.com/kyverno/kyverno-json/pkg/core/templating"
tassert "github.com/stretchr/testify/assert"
"k8s.io/apimachinery/pkg/util/validation/field"
)
Expand Down Expand Up @@ -48,9 +48,10 @@ func TestAssert(t *testing.T) {
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parsed, err := Parse(context.TODO(), tt.assertion)
compiler := templating.NewCompiler(templating.CompilerOptions{})
parsed, err := Parse(tt.assertion, compiler)
tassert.NoError(t, err)
got, err := parsed.Assert(context.TODO(), nil, tt.value, tt.bindings)
got, err := parsed.Assert(nil, tt.value, tt.bindings)
if tt.wantErr {
tassert.Error(t, err)
} else {
Expand Down
17 changes: 8 additions & 9 deletions pkg/core/message/message.go
Original file line number Diff line number Diff line change
@@ -1,25 +1,24 @@
package message

import (
"context"
"fmt"
"regexp"
"strings"
"sync"

"github.com/jmespath-community/go-jmespath/pkg/binding"
"github.com/jmespath-community/go-jmespath/pkg/parsing"
"github.com/kyverno/kyverno-json/pkg/engine/template"
"github.com/kyverno/kyverno-json/pkg/core/templating/jp"
)

var variable = regexp.MustCompile(`{{(.*?)}}`)

type Message interface {
Original() string
Format(any, binding.Bindings, ...template.Option) string
Format(any, binding.Bindings, ...jp.Option) string
}

type substitution = func(string, any, binding.Bindings, ...template.Option) string
type substitution = func(string, any, binding.Bindings, ...jp.Option) string

type message struct {
original string
Expand All @@ -30,7 +29,7 @@ func (m *message) Original() string {
return m.original
}

func (m *message) Format(value any, bindings binding.Bindings, opts ...template.Option) string {
func (m *message) Format(value any, bindings binding.Bindings, opts ...jp.Option) string {
out := m.original
for _, substitution := range m.substitutions {
out = substitution(out, value, bindings, opts...)
Expand All @@ -40,22 +39,22 @@ func (m *message) Format(value any, bindings binding.Bindings, opts ...template.

func Parse(in string) *message {
groups := variable.FindAllStringSubmatch(in, -1)
var substitutions []func(string, any, binding.Bindings, ...template.Option) string
var substitutions []func(string, any, binding.Bindings, ...jp.Option) string
for _, group := range groups {
statement := strings.TrimSpace(group[1])
parse := sync.OnceValues(func() (parsing.ASTNode, error) {
parser := parsing.NewParser()
return parser.Parse(statement)
})
evaluate := func(value any, bindings binding.Bindings, opts ...template.Option) (any, error) {
evaluate := func(value any, bindings binding.Bindings, opts ...jp.Option) (any, error) {
ast, err := parse()
if err != nil {
return nil, err
}
return template.ExecuteAST(context.TODO(), ast, value, bindings, opts...)
return jp.Execute(ast, value, bindings, opts...)
}
placeholder := group[0]
substitutions = append(substitutions, func(out string, value any, bindings binding.Bindings, opts ...template.Option) string {
substitutions = append(substitutions, func(out string, value any, bindings binding.Bindings, opts ...jp.Option) string {
result, err := evaluate(value, bindings, opts...)
if err != nil {
out = strings.ReplaceAll(out, placeholder, fmt.Sprintf("ERR (%s - %s)", statement, err))
Expand Down
Loading

0 comments on commit 20d691f

Please sign in to comment.