Skip to content

Commit

Permalink
Add custom label support
Browse files Browse the repository at this point in the history
  • Loading branch information
ccampo133 committed Apr 5, 2024
1 parent db43f9b commit 87db7a6
Show file tree
Hide file tree
Showing 10 changed files with 290 additions and 184 deletions.
12 changes: 8 additions & 4 deletions classification/classification.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ package classification
import (
	"context"
	"encoding/json"
	"sort"

	"golang.org/x/exp/maps"
)

// Classifier is an interface that represents a data classifier. A classifier
Expand All @@ -29,11 +27,17 @@ type Classifier interface {
// that attribute was classified as.
type Result map[string]LabelSet

// LabelSet is a set of unique label names.
type LabelSet map[string]struct{}

// MarshalJSON marshals the LabelSet into a JSON array of strings, where each
// string is the name of a label in the set. The names are emitted in sorted
// order so that the same set always produces the same bytes (Go map iteration
// order is randomized). A nil or empty set marshals to an empty array ("[]")
// rather than "null".
func (l LabelSet) MarshalJSON() ([]byte, error) {
	keys := make([]string, 0, len(l))
	for k := range l {
		keys = append(keys, k)
	}
	// Sort so the output is deterministic and diff/test friendly.
	sort.Strings(keys)
	return json.Marshal(keys)
}

// Classification represents the classification of a data repository attribute.
Expand Down
127 changes: 99 additions & 28 deletions classification/label.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,40 @@ package classification

import (
"embed"
"errors"
"fmt"
"strings"
"io/fs"
"os"
"path/filepath"

"github.com/open-policy-agent/opa/ast"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"gopkg.in/yaml.v3"
)

var (
//go:embed labels/*.rego
regoFs embed.FS
//go:embed labels/labels.yaml
labelsYaml string
//go:embed labels/*.rego labels/labels.yaml
predefinedLabelsFs embed.FS
)

// InvalidLabelsError is an error type that represents an error when one or
// more labels are invalid, e.g. they have invalid classification rules. It
// holds the individual errors that caused the problem, which can be retrieved
// with Unwrap (and are therefore visible to errors.Is and errors.As).
type InvalidLabelsError struct {
	errs []error
}

// Unwrap returns the errors that caused the InvalidLabelsError.
func (e InvalidLabelsError) Unwrap() []error {
	return e.errs
}

// Error returns a string representation of the InvalidLabelsError: the
// messages of the underlying errors joined by newlines. If there are no
// underlying (non-nil) errors, a generic message is returned instead.
func (e InvalidLabelsError) Error() string {
	// errors.Join returns nil when given no non-nil errors; calling Error()
	// on that nil result would panic, so guard against it.
	joined := errors.Join(e.errs...)
	if joined == nil {
		return "invalid labels"
	}
	return joined.Error()
}

// Label represents a data classification label.
type Label struct {
// Name is the name of the label.
Expand Down Expand Up @@ -48,36 +66,89 @@ func NewLabel(name, description, classificationRule string, tags ...string) (Lab
}, nil
}

// GetEmbeddedLabels returns the predefined embedded labels and their
// classification rules. The labels are read from the embedded labels.yaml file
// and the classification rules are read from the embedded Rego files. If there
// is an error unmarshalling the labels file, it is returned. If there is an
// error reading or parsing a classification rule for a label, a warning is
// logged and that label is skipped.
func GetEmbeddedLabels() ([]Label, error) {
labels := struct {
Labels map[string]Label `yaml:"labels"`
}{}
if err := yaml.Unmarshal([]byte(labelsYaml), &labels); err != nil {
return nil, fmt.Errorf("error unmarshalling labels.yaml: %w", err)
// GetPredefinedLabels loads the predefined labels that are embedded in the
// binary, together with their classification rules. Label metadata comes from
// the embedded labels/labels.yaml file and each rule comes from an embedded
// Rego file. An error reading or unmarshalling the yaml file is returned
// as-is with no labels. Errors reading or parsing individual classification
// rules are aggregated into an InvalidLabelsError, which is returned along
// with the labels that did load successfully. In practice no error is
// expected here, since the embedded labels should always be valid; an error
// indicates a problem with the embedded labels themselves.
func GetPredefinedLabels() ([]Label, error) {
	const predefinedLabelsYaml = "labels/labels.yaml"
	return getLabels(predefinedLabelsYaml, predefinedLabelsFs)
}

// GetCustomLabels loads and returns the labels and their classification rules
// defined in the given labels yaml file. The labels are read from the file
// along with their classification rule Rego files (defined in the yaml). If
// there is an error unmarshalling the labels file, it is returned. If there are
// errors reading or parsing a classification rules for labels, the errors are
// aggregated into an InvalidLabelsError and returned, along with the labels
// that were successfully read.
func GetCustomLabels(labelsYamlFname string) ([]Label, error) {
path, err := filepath.Abs(labelsYamlFname)
if err != nil {
return nil, fmt.Errorf("error getting absolute path for labels yaml file %s: %w", labelsYamlFname, err)
}
for name, lbl := range labels.Labels {
fname := "labels/" + strings.ReplaceAll(strings.ToLower(name), " ", "_") + ".rego"
b, err := regoFs.ReadFile(fname)
if err != nil {
log.WithError(err).Warnf("error reading rego file %s", fname)
continue
labelFs := os.DirFS(filepath.Dir(path))
return getLabels(filepath.Base(path), labelFs.(fs.ReadFileFS))
}

// getLabels reads the labels yaml file fname from labelsFs and loads the
// classification rule for every label defined in it. The yaml is expected to
// be a map of label name to label definition, where each definition has a
// "rule" key naming the label's Rego file.
//
// A failure to read or unmarshal the yaml file is returned immediately with
// no labels. Failures to read or parse individual rules do not stop
// processing: they are collected into an InvalidLabelsError, which is
// returned together with the labels that loaded successfully.
func getLabels(fname string, labelsFs fs.ReadFileFS) ([]Label, error) {
	// Read and parse the labels yaml file.
	yamlBytes, err := labelsFs.ReadFile(fname)
	if err != nil {
		return nil, fmt.Errorf("error reading label yaml file %s: %w", fname, err)
	}
	// NOTE(review): yaml.v3 appears to flatten embedded structs only when
	// they are tagged `yaml:",inline"` - confirm that Label's fields are
	// decoded from the yaml as intended.
	type yamlLabel struct {
		Label
		Rule string `yaml:"rule"`
	}
	yamlLabels := make(map[string]yamlLabel)
	if err := yaml.Unmarshal(yamlBytes, &yamlLabels); err != nil {
		return nil, fmt.Errorf("error unmarshalling labels yaml: %w", err)
	}
	labels := make([]Label, 0, len(yamlLabels))
	// Read each label's classification rule.
	var errs []error
	for name, lbl := range yamlLabels {
		// Rules are normally given relative to the labels yaml file, in which
		// case they are read from the same filesystem. fs.FS names must be
		// slash-separated on every OS, hence the ToSlash.
		ruleFs := labelsFs
		ruleFname := filepath.ToSlash(filepath.Join(filepath.Dir(fname), lbl.Rule))
		// An absolute rule path cannot be served by labelsFs, so a new fs is
		// rooted at the rule's directory. The check must be made on the raw
		// rule path: joining it with the yaml directory first would silently
		// turn it into a relative path. This implies that embedded labels will
		// not work with absolute rules, but theirs should always be relative.
		if filepath.IsAbs(lbl.Rule) {
			ruleFs = os.DirFS(filepath.Dir(lbl.Rule)).(fs.ReadFileFS)
			ruleFname = filepath.Base(lbl.Rule)
		}
		rule, err := readLabelRule(ruleFname, ruleFs)
		if err != nil {
			errs = append(errs, fmt.Errorf("error reading classification rule for label %s: %w", name, err))
			continue
		}
		lbl.Name = name
		lbl.ClassificationRule = rule
		labels = append(labels, lbl.Label)
	}
	if len(errs) > 0 {
		return labels, InvalidLabelsError{errs}
	}
	return labels, nil
}

// readLabelRule reads the Rego file fname from labelFs and parses it into a
// Rego module. Read and parse failures are wrapped with the file name and
// returned with a nil module.
func readLabelRule(fname string, labelFs fs.ReadFileFS) (*ast.Module, error) {
	src, readErr := labelFs.ReadFile(fname)
	if readErr != nil {
		return nil, fmt.Errorf("error reading rego file %s: %w", fname, readErr)
	}
	module, parseErr := parseRego(string(src))
	if parseErr != nil {
		return nil, fmt.Errorf("error parsing classification rule for file %s: %w", fname, parseErr)
	}
	return module, nil
}

func parseRego(code string) (*ast.Module, error) {
Expand Down
8 changes: 5 additions & 3 deletions classification/label_classifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package classification

import (
"context"
"errors"
"fmt"

"github.com/open-policy-agent/opa/rego"
Expand Down Expand Up @@ -45,12 +46,13 @@ func NewLabelClassifier(ctx context.Context, labels ...Label) (*LabelClassifier,
// names to the set of labels that the attribute was classified as.
func (c *LabelClassifier) Classify(ctx context.Context, input map[string]any) (Result, error) {
result := make(Result, len(c.queries))
var errs error
for lbl, query := range c.queries {
output, err := evalQuery(ctx, query, input)
if err != nil {
// A single error should not prevent the classification of other
// labels. Log the error and continue.
log.WithError(err).Errorf("error evaluating query for label %s", lbl)
// labels. Aggregate the error and continue.
errs = errors.Join(errs, fmt.Errorf("error evaluating query for label %s: %w", lbl, err))
continue
}
log.Debugf("classification results for label %s: %v", lbl, output)
Expand All @@ -67,7 +69,7 @@ func (c *LabelClassifier) Classify(ctx context.Context, input map[string]any) (R
}
}
}
return result, nil
return result, errs
}

// evalQuery evaluates the provided Rego query with the given attributes as input, and returns the classification results. The output is a
Expand Down
6 changes: 3 additions & 3 deletions classification/label_classifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
// TestNewLabelClassifier_Success checks that a classifier can be constructed
// from a single label whose classification rule is valid Rego: construction
// must not return an error and must yield a non-nil classifier.
func TestNewLabelClassifier_Success(t *testing.T) {
lbl, err := NewLabel("foo", "test label", "package foo\noutput = true")
require.NoError(t, err)
classifier, err := NewLabelClassifier(lbl)
classifier, err := NewLabelClassifier(context.Background(), lbl)
require.NoError(t, err)
require.NotNil(t, classifier)
}
Expand Down Expand Up @@ -140,14 +140,14 @@ func newTestLabelClassifier(t *testing.T, lblNames ...string) *LabelClassifier {
for i, lblName := range lblNames {
lbls[i] = newTestLabel(t, lblName)
}
classifier, err := NewLabelClassifier(lbls...)
classifier, err := NewLabelClassifier(context.Background(), lbls...)
require.NoError(t, err)
return classifier
}

func newTestLabel(t *testing.T, lblName string) Label {
fname := "labels/" + strings.ReplaceAll(strings.ToLower(lblName), " ", "_") + ".rego"
fin, err := regoFs.ReadFile(fname)
fin, err := predefinedLabelsFs.ReadFile(fname)
require.NoError(t, err)
classifierCode := string(fin)
lbl, err := NewLabel(lblName, "test label", classifierCode)
Expand Down
79 changes: 29 additions & 50 deletions classification/label_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,64 +2,43 @@ package classification

import (
"context"
"fmt"
"testing"

"github.com/open-policy-agent/opa/rego"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"
)

func TestGetEmbeddedLabels(t *testing.T) {
got, err := GetEmbeddedLabels()
// This isn't really a test as much as it is a validation of the embedded labels
// and their classification rules. The test will fail if the labels or their
// classification rules are invalid for any reason. It includes parsing the
// label Rego code, and validating that have their expected output. It serves
// mostly as a build-time a sanity check, which should hopefully avoid us
// releasing a build with broken embedded labels!
func TestGetPredefinedLabels_LabelsAreValid(t *testing.T) {
// We want to read the labels ourselves in this test, so we can compare them
// against the labels returned by GetPredefinedLabels.
fname := "labels/labels.yaml"
type yamlLabel struct {
Label
Rule string `yaml:"rule"`
}
yamlBytes, err := predefinedLabelsFs.ReadFile(fname)
require.NoError(t, err)
require.NotEmpty(t, got)
}

func TestRego(t *testing.T) {

module := `
package example.authz
import rego.v1
default allow := false
allow if {
input.method == "GET"
input.path == ["salary", input.subject.user]
}
allow if is_admin
is_admin if "admin" in input.subject.groups
`

mod, err := parseRego(module)
yamlLabels := make(map[string]yamlLabel)
err = yaml.Unmarshal(yamlBytes, &yamlLabels)
require.NoError(t, err)
require.NotNil(t, mod)
path := mod.Package.Path.String()
fmt.Println(path)

ctx := context.TODO()

query, err := rego.New(
rego.Query("data.example.authz.allow"),
rego.Module("example.rego", module),
).PrepareForEval(ctx)
got, err := GetPredefinedLabels()
require.NoError(t, err)
require.NotNil(t, query)

input := map[string]interface{}{
"method": "GET",
"path": []interface{}{"salary", "bob"},
"subject": map[string]interface{}{
"user": "bob",
"groups": []interface{}{"sales", "marketing"},
},
}

results, err := query.Eval(ctx, rego.EvalInput(input))
require.Len(t, got, len(yamlLabels))

// Validate the classification rules for each label by doing dummy
// classification. We don't expect any results for an empty input - we
// really just want to validate that the classification rules are valid and
// there were no errors during the classification process.
ctx := context.Background()
classifier, err := NewLabelClassifier(ctx, got...)
res, err := classifier.Classify(ctx, map[string]any{})
require.NoError(t, err)
require.NotEmpty(t, results)
require.True(t, results.Allowed())
require.Empty(t, res)
}
14 changes: 7 additions & 7 deletions classification/labels/README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
This directory contains all the data label definitions used for classification.
The label metadata is specified in the [`labels.yaml`](labels.yaml) file. Please
see that file's doc comment for more details.
This directory contains all the predefined data label definitions used for
classification. The label metadata is specified in the
[`labels.yaml`](labels.yaml) file, and the classification rules are defined in
individual Rego files for each label.

Additionally, the classification rule Rego code for each label must be specified
as a `<label>.rego` file, where `<label>` is the name of the label in lowercase.
For example, if the label `ADDRESS` is defined in `labels.yaml`, it should have
an `address.rego` file defined as well.
To add a new predefined label, add its metadata to [`labels.yaml`](labels.yaml)
(following the file's instructions), as well as a corresponding classification
rule Rego file.
Loading

0 comments on commit 87db7a6

Please sign in to comment.