diff --git a/v1/ast/oracle/oracle.go b/v1/ast/oracle/oracle.go index 801b502a86..d09036a893 100644 --- a/v1/ast/oracle/oracle.go +++ b/v1/ast/oracle/oracle.go @@ -1,9 +1,13 @@ +// Copyright 2020 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. package oracle import ( "errors" "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/util" ) // Error defines the structure of errors returned by the oracle. @@ -17,6 +21,7 @@ func (e Error) Error() string { // Oracle implements different queries over ASTs, e.g., find definition. type Oracle struct { + compiler *ast.Compiler } // New returns a new Oracle object. @@ -26,10 +31,10 @@ func New() *Oracle { // DefinitionQuery defines a Rego definition query. type DefinitionQuery struct { - Filename string // name of file to search for position inside of - Pos int // position to search for Modules map[string]*ast.Module // workspace modules; buffer may shadow a file inside the workspace + Filename string // name of file to search for position inside of Buffer []byte // buffer that overrides module with filename + Pos int // position to search for } var ( @@ -45,6 +50,14 @@ type DefinitionQueryResult struct { Result *ast.Location `json:"result"` } +// WithCompiler sets the compiler to use for the oracle. If not set, a new ast.Compiler +// will be created when needed. +func (o *Oracle) WithCompiler(compiler *ast.Compiler) *Oracle { + o.compiler = compiler + + return o +} + // FindDefinition returns the location of the definition referred to by the symbol // at the position in q. func (o *Oracle) FindDefinition(q DefinitionQuery) (*DefinitionQueryResult, error) { @@ -54,7 +67,7 @@ func (o *Oracle) FindDefinition(q DefinitionQuery) (*DefinitionQueryResult, erro // Ditto for caching across runs. Avoid repeating the same work. // NOTE(sr): "SetRuleTree" because it's needed for compiler.GetRulesExact() below - compiler, parsed, err := compileUpto("SetRuleTree", q.Modules, q.Buffer, q.Filename) + compiler, parsed, err := o.compileUpto("SetRuleTree", q.Modules, q.Buffer, q.Filename) if err != nil { return nil, err } @@ -111,50 +124,28 @@ func (o *Oracle) FindDefinition(q DefinitionQuery) (*DefinitionQueryResult, erro return nil, ErrNoDefinitionFound } -func walkToFirstOccurrence(node ast.Node, needle ast.Var) (match *ast.Term) { - ast.WalkNodes(node, func(x ast.Node) bool { - if match == nil { - switch x := x.(type) { - case *ast.SomeDecl: - // NOTE(tsandall): The visitor doesn't traverse into some decl terms - // so special case here. - for i := range x.Symbols { - if x.Symbols[i].Value.Compare(needle) == 0 { - match = x.Symbols[i] - break - } - } - case *ast.Term: - if x.Value.Compare(needle) == 0 { - match = x - } +func (o *Oracle) compileUpto(stage string, modules map[string]*ast.Module, bs []byte, filename string) (*ast.Compiler, *ast.Module, error) { + var compiler *ast.Compiler + if o.compiler != nil { + compiler = o.compiler + } else { + compiler = ast.NewCompiler() + } + + compiler = compiler.WithStageAfter(stage, ast.CompilerStageDefinition{ + Name: "halt", + Stage: func(c *ast.Compiler) *ast.Error { + return &ast.Error{ + Code: "halt", } - } - return match != nil + }, }) - return match -} - -func compileUpto(stage string, modules map[string]*ast.Module, bs []byte, filename string) (*ast.Compiler, *ast.Module, error) { - - compiler := ast.NewCompiler() - - if stage != "" { - compiler = compiler.WithStageAfter(stage, ast.CompilerStageDefinition{ - Name: "halt", - Stage: func(_ *ast.Compiler) *ast.Error { - return &ast.Error{ - Code: "halt", - } - }, - }) - } var module *ast.Module if len(bs) > 0 { var err error - module, err = ast.ParseModule(filename, string(bs)) + module, err = ast.ParseModule(filename, util.ByteSliceToString(bs)) if err != nil { return nil, nil, err } @@ -192,6 +183,30 @@ func halted(c *ast.Compiler) error { return errors.New("unreachable: did not halt") } +func walkToFirstOccurrence(node ast.Node, needle ast.Var) (match *ast.Term) { + ast.WalkNodes(node, func(x ast.Node) bool { + if match == nil { + switch x := x.(type) { + case *ast.SomeDecl: + // NOTE(tsandall): The visitor doesn't traverse into some decl terms + // so special case here. + for i := range x.Symbols { + if x.Symbols[i].Value.Compare(needle) == 0 { + match = x.Symbols[i] + break + } + } + case *ast.Term: + if x.Value.Compare(needle) == 0 { + match = x + } + } + } + return match != nil + }) + return match +} + func findContainingNodeStack(module *ast.Module, pos int) []ast.Node { var matches []ast.Node diff --git a/v1/ast/oracle/oracle_test.go b/v1/ast/oracle/oracle_test.go index ac3737f723..e6094b717b 100644 --- a/v1/ast/oracle/oracle_test.go +++ b/v1/ast/oracle/oracle_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/open-policy-agent/opa/v1/ast" + "github.com/open-policy-agent/opa/v1/metrics" ) func TestOracleFindDefinitionErrors(t *testing.T) { @@ -484,7 +485,7 @@ q = true` func TestCompileUptoNoModules(t *testing.T) { - compiler, module, err := compileUpto("SetRuleTree", nil, []byte("package test\np=1"), "test.rego") + compiler, module, err := New().compileUpto("SetRuleTree", nil, []byte("package test\np=1"), "test.rego") if err != nil { t.Fatal(err) } @@ -502,7 +503,7 @@ func TestCompileUptoNoModules(t *testing.T) { func TestCompileUptoNoBuffer(t *testing.T) { - compiler, module, err := compileUpto("SetRuleTree", map[string]*ast.Module{ + compiler, module, err := New().compileUpto("SetRuleTree", map[string]*ast.Module{ "test.rego": ast.MustParseModule("package test\np=1"), }, nil, "test.rego") if err != nil { @@ -522,7 +523,7 @@ func TestCompileUptoNoBuffer(t *testing.T) { func TestCompileUptoBadStageName(t *testing.T) { - _, _, err := compileUpto("DEADBEEF", map[string]*ast.Module{ + _, _, err := New().compileUpto("DEADBEEF", map[string]*ast.Module{ "test.rego": ast.MustParseModule("package test\np=1"), }, nil, "test.rego") @@ -530,3 +531,18 @@ func TestCompileUptoBadStageName(t *testing.T) { t.Fatal("expected halt error but got:", err) } } + +func TestUsingCustomCompiler(t *testing.T) { + m := metrics.New() + o := New().WithCompiler(ast.NewCompiler().WithMetrics(m)) + q := DefinitionQuery{Modules: map[string]*ast.Module{"test.rego": ast.MustParseModule("package test\np=1")}} + + if _, err := o.FindDefinition(q); !errors.Is(err, ErrNoMatchFound) { + t.Fatal("expected no definition found error but got:", err) + } + + // Ensure metrics set on the custom compiler have been updated + if m.Timer("compile_stage_check_imports").Int64() == 0 { + t.Fatal("expected metrics to be updated") + } +}