Skip to content

Commit

Permalink
Allow passing own compiler to oracle (#7354)
Browse files Browse the repository at this point in the history
Verified that this works in Regal, but also added a trivial
test to assert that a custom compiler passed is used.

Signed-off-by: Anders Eknert <anders@styra.com>
  • Loading branch information
anderseknert authored Feb 10, 2025
1 parent 9546567 commit af64edb
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 43 deletions.
95 changes: 55 additions & 40 deletions v1/ast/oracle/oracle.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
// Copyright 2020 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package oracle

import (
"errors"

"github.com/open-policy-agent/opa/v1/ast"
"github.com/open-policy-agent/opa/v1/util"
)

// Error defines the structure of errors returned by the oracle.
Expand All @@ -17,6 +21,7 @@ func (e Error) Error() string {

// Oracle implements different queries over ASTs, e.g., find definition.
type Oracle struct {
compiler *ast.Compiler
}

// New returns a new Oracle object.
Expand All @@ -26,10 +31,10 @@ func New() *Oracle {

// DefinitionQuery defines a Rego definition query.
type DefinitionQuery struct {
Filename string // name of file to search for position inside of
Pos int // position to search for
Modules map[string]*ast.Module // workspace modules; buffer may shadow a file inside the workspace
Filename string // name of file to search for position inside of
Buffer []byte // buffer that overrides module with filename
Pos int // position to search for
}

var (
Expand All @@ -45,6 +50,14 @@ type DefinitionQueryResult struct {
Result *ast.Location `json:"result"`
}

// WithCompiler sets the compiler to use for the oracle. If not set, a new ast.Compiler
// will be created when needed.
func (o *Oracle) WithCompiler(compiler *ast.Compiler) *Oracle {
o.compiler = compiler

return o
}

// FindDefinition returns the location of the definition referred to by the symbol
// at the position in q.
func (o *Oracle) FindDefinition(q DefinitionQuery) (*DefinitionQueryResult, error) {
Expand All @@ -54,7 +67,7 @@ func (o *Oracle) FindDefinition(q DefinitionQuery) (*DefinitionQueryResult, erro
// Ditto for caching across runs. Avoid repeating the same work.

// NOTE(sr): "SetRuleTree" because it's needed for compiler.GetRulesExact() below
compiler, parsed, err := compileUpto("SetRuleTree", q.Modules, q.Buffer, q.Filename)
compiler, parsed, err := o.compileUpto("SetRuleTree", q.Modules, q.Buffer, q.Filename)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -111,50 +124,28 @@ func (o *Oracle) FindDefinition(q DefinitionQuery) (*DefinitionQueryResult, erro
return nil, ErrNoDefinitionFound
}

func walkToFirstOccurrence(node ast.Node, needle ast.Var) (match *ast.Term) {
ast.WalkNodes(node, func(x ast.Node) bool {
if match == nil {
switch x := x.(type) {
case *ast.SomeDecl:
// NOTE(tsandall): The visitor doesn't traverse into some decl terms
// so special case here.
for i := range x.Symbols {
if x.Symbols[i].Value.Compare(needle) == 0 {
match = x.Symbols[i]
break
}
}
case *ast.Term:
if x.Value.Compare(needle) == 0 {
match = x
}
func (o *Oracle) compileUpto(stage string, modules map[string]*ast.Module, bs []byte, filename string) (*ast.Compiler, *ast.Module, error) {
var compiler *ast.Compiler
if o.compiler != nil {
compiler = o.compiler
} else {
compiler = ast.NewCompiler()
}

compiler = compiler.WithStageAfter(stage, ast.CompilerStageDefinition{
Name: "halt",
Stage: func(c *ast.Compiler) *ast.Error {
return &ast.Error{
Code: "halt",
}
}
return match != nil
},
})
return match
}

func compileUpto(stage string, modules map[string]*ast.Module, bs []byte, filename string) (*ast.Compiler, *ast.Module, error) {

compiler := ast.NewCompiler()

if stage != "" {
compiler = compiler.WithStageAfter(stage, ast.CompilerStageDefinition{
Name: "halt",
Stage: func(_ *ast.Compiler) *ast.Error {
return &ast.Error{
Code: "halt",
}
},
})
}

var module *ast.Module

if len(bs) > 0 {
var err error
module, err = ast.ParseModule(filename, string(bs))
module, err = ast.ParseModule(filename, util.ByteSliceToString(bs))
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -192,6 +183,30 @@ func halted(c *ast.Compiler) error {
return errors.New("unreachable: did not halt")
}

func walkToFirstOccurrence(node ast.Node, needle ast.Var) (match *ast.Term) {
ast.WalkNodes(node, func(x ast.Node) bool {
if match == nil {
switch x := x.(type) {
case *ast.SomeDecl:
// NOTE(tsandall): The visitor doesn't traverse into some decl terms
// so special case here.
for i := range x.Symbols {
if x.Symbols[i].Value.Compare(needle) == 0 {
match = x.Symbols[i]
break
}
}
case *ast.Term:
if x.Value.Compare(needle) == 0 {
match = x
}
}
}
return match != nil
})
return match
}

func findContainingNodeStack(module *ast.Module, pos int) []ast.Node {

var matches []ast.Node
Expand Down
22 changes: 19 additions & 3 deletions v1/ast/oracle/oracle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"testing"

"github.com/open-policy-agent/opa/v1/ast"
"github.com/open-policy-agent/opa/v1/metrics"
)

func TestOracleFindDefinitionErrors(t *testing.T) {
Expand Down Expand Up @@ -484,7 +485,7 @@ q = true`

func TestCompileUptoNoModules(t *testing.T) {

compiler, module, err := compileUpto("SetRuleTree", nil, []byte("package test\np=1"), "test.rego")
compiler, module, err := New().compileUpto("SetRuleTree", nil, []byte("package test\np=1"), "test.rego")
if err != nil {
t.Fatal(err)
}
Expand All @@ -502,7 +503,7 @@ func TestCompileUptoNoModules(t *testing.T) {

func TestCompileUptoNoBuffer(t *testing.T) {

compiler, module, err := compileUpto("SetRuleTree", map[string]*ast.Module{
compiler, module, err := New().compileUpto("SetRuleTree", map[string]*ast.Module{
"test.rego": ast.MustParseModule("package test\np=1"),
}, nil, "test.rego")
if err != nil {
Expand All @@ -522,11 +523,26 @@ func TestCompileUptoNoBuffer(t *testing.T) {

func TestCompileUptoBadStageName(t *testing.T) {

_, _, err := compileUpto("DEADBEEF", map[string]*ast.Module{
_, _, err := New().compileUpto("DEADBEEF", map[string]*ast.Module{
"test.rego": ast.MustParseModule("package test\np=1"),
}, nil, "test.rego")

if err.Error() != "unreachable: did not halt" {
t.Fatal("expected halt error but got:", err)
}
}

func TestUsingCustomCompiler(t *testing.T) {
m := metrics.New()
o := New().WithCompiler(ast.NewCompiler().WithMetrics(m))
q := DefinitionQuery{Modules: map[string]*ast.Module{"test.rego": ast.MustParseModule("package test\np=1")}}

if _, err := o.FindDefinition(q); !errors.Is(err, ErrNoMatchFound) {
t.Fatal("expected no definition found error but got:", err)
}

// Ensure metrics set on the custom compiler have been updated
if m.Timer("compile_stage_check_imports").Int64() == 0 {
t.Fatal("expected metrics to be updated")
}
}

0 comments on commit af64edb

Please sign in to comment.