Skip to content

Commit

Permalink
Fix relative path bugs.
Browse files Browse the repository at this point in the history
  • Loading branch information
ccampo133 committed Apr 8, 2024
1 parent 87db7a6 commit c4628c8
Show file tree
Hide file tree
Showing 3 changed files with 210 additions and 21 deletions.
74 changes: 56 additions & 18 deletions classification/label.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func NewLabel(name, description, classificationRule string, tags ...string) (Lab
// error in reality, as the embedded labels should always be valid. If it does,
// it indicates a problem with the embedded labels!
func GetPredefinedLabels() ([]Label, error) {
return getLabels("labels/labels.yaml", predefinedLabelsFs)
return getLabels("labels/labels.yaml", true)
}

// GetCustomLabels loads and returns the labels and their classification rules
Expand All @@ -91,19 +91,36 @@ func GetCustomLabels(labelsYamlFname string) ([]Label, error) {
if err != nil {
return nil, fmt.Errorf("error getting absolute path for labels yaml file %s: %w", labelsYamlFname, err)
}
labelFs := os.DirFS(filepath.Dir(path))
return getLabels(filepath.Base(path), labelFs.(fs.ReadFileFS))
return getLabels(path, false)
}

func getLabels(fname string, labelsFs fs.ReadFileFS) ([]Label, error) {
// getLabels reads the labels YAML file from the given path and returns the
// labels and their classification rules. If predefined is true, the labels are
// read from the embedded FS, otherwise they are read from the file system. If
// there is an error reading or unmarshalling the labels file, it is returned.
// If there are errors reading or parsing the classification rules for labels, the
// errors are aggregated into an InvalidLabelsError and returned, along with the
// labels that were successfully read.
func getLabels(path string, predefined bool) ([]Label, error) {
var (
labelsFs fs.ReadFileFS
labelsFname string
)
if predefined {
labelsFs = predefinedLabelsFs
labelsFname = path
} else {
labelsFs = os.DirFS(filepath.Dir(path)).(fs.ReadFileFS)
labelsFname = filepath.Base(path)
}
// Read and parse the labels yaml file.
yamlBytes, err := labelsFs.ReadFile(fname)
yamlBytes, err := labelsFs.ReadFile(labelsFname)
if err != nil {
return nil, fmt.Errorf("error reading label yaml file %s", fname)
return nil, fmt.Errorf("error reading label yaml file %s", path)
}
type yamlLabel struct {
Label
Rule string `yaml:"rule"`
Label `yaml:",inline"`
Rule string `yaml:"rule"`
}
yamlLabels := make(map[string]yamlLabel)
if err := yaml.Unmarshal(yamlBytes, &yamlLabels); err != nil {
Expand All @@ -113,16 +130,37 @@ func getLabels(fname string, labelsFs fs.ReadFileFS) ([]Label, error) {
// Read each label's classification rule.
var errs []error
for name, lbl := range yamlLabels {
// We assume that the rule will be defined with a relative path in the
// labels yaml file. If the rule is absolute, we create a new fs for
// the rule's directory. This implies that embedded labels will not work
// if the rule is absolute, but that should not be a problem because
// they should always be relative.
ruleFname := filepath.Join(filepath.Dir(fname), lbl.Rule)
ruleFs := labelsFs
if filepath.IsAbs(ruleFname) {
ruleFs = os.DirFS(filepath.Dir(ruleFname)).(fs.ReadFileFS)
ruleFname = filepath.Base(ruleFname)
// The rule file for this label is either an absolute path, a relative
// path, or a predefined rule. We need to determine the fs to use to
// read the rule file.
var (
ruleFs fs.ReadFileFS
ruleFname string
)
if predefined {
// We're dealing with the predefined labels, therefore the rule FS
// is the same embedded FS as the labels YAML file. However, we need to
// use the dir of the labels yaml file as the rule file root because
// this is the root of the embedded fs - it's a bit of a quirk with
// the embedded FS API.
ruleFname = filepath.Join(filepath.Dir(path), lbl.Rule)
ruleFs = labelsFs
} else {
ruleFname = filepath.Base(lbl.Rule)
if filepath.IsAbs(lbl.Rule) {
// The rule has an absolute path, so we need to create a new fs
// for the rule's directory.
ruleFs = os.DirFS(filepath.Dir(lbl.Rule)).(fs.ReadFileFS)
} else {
// The rule has a relative path, which is relative to the labels
// YAML file (as opposed to the current directory).
ruleFs = os.DirFS(
filepath.Join(
filepath.Dir(path),
filepath.Dir(lbl.Rule),
),
).(fs.ReadFileFS)
}
}
rule, err := readLabelRule(ruleFname, ruleFs)
if err != nil {
Expand Down
152 changes: 150 additions & 2 deletions classification/label_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package classification

import (
"context"
"fmt"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/require"
Expand All @@ -19,26 +22,171 @@ func TestGetPredefinedLabels_LabelsAreValid(t *testing.T) {
// against the labels returned by GetPredefinedLabels.
fname := "labels/labels.yaml"
type yamlLabel struct {
Label
Rule string `yaml:"rule"`
Label `yaml:",inline"`
Rule string `yaml:"rule"`
}
yamlBytes, err := predefinedLabelsFs.ReadFile(fname)
require.NoError(t, err)
yamlLabels := make(map[string]yamlLabel)
err = yaml.Unmarshal(yamlBytes, &yamlLabels)
require.NoError(t, err)
for name, lbl := range yamlLabels {
lbl.Name = name
yamlLabels[name] = lbl
}

got, err := GetPredefinedLabels()
require.NoError(t, err)
require.Len(t, got, len(yamlLabels))
for i, lbl := range got {
want := yamlLabels[lbl.Name]
require.Equal(t, want.Name, got[i].Name)
require.Equal(t, want.Description, got[i].Description)
require.ElementsMatch(t, want.Tags, got[i].Tags)
// We don't care about the actual rule content here, just that it was
// loaded. We'll validate the content below.
require.NotNil(t, got[i].ClassificationRule)
}

// Validate the classification rules for each label by doing dummy
// classification. We don't expect any results for an empty input - we
// really just want to validate that the classification rules are valid and
// there were no errors during the classification process.
ctx := context.Background()
classifier, err := NewLabelClassifier(ctx, got...)
require.NoError(t, err)
res, err := classifier.Classify(ctx, map[string]any{})
require.NoError(t, err)
require.Empty(t, res)
}

// TestGetCustomLabels_RelativeRulePath_SameDir verifies that GetCustomLabels
// loads a label whose rule is given as a bare file name, i.e. a relative path
// to a rule file that lives in the same directory as the labels YAML file.
func TestGetCustomLabels_RelativeRulePath_SameDir(t *testing.T) {
	labelsDir := t.TempDir()

	// Create the labels YAML file. The keys nested under ADDRESS must be
	// indented (with spaces) so the document parses as a single label rather
	// than as four unrelated top-level keys.
	labelsYamlFname := filepath.Join(labelsDir, "labels.yaml")
	labelsYaml := `ADDRESS:
  description: Address
  rule: address.rego
  tags:
    - PII`
	err := os.WriteFile(labelsYamlFname, []byte(labelsYaml), 0o644)
	require.NoError(t, err)

	// Create the rule rego file next to the labels YAML file.
	ruleFname := filepath.Join(labelsDir, "address.rego")
	err = os.WriteFile(ruleFname, []byte("package foo"), 0o644)
	require.NoError(t, err)

	want := []Label{
		{
			Name:        "ADDRESS",
			Description: "Address",
			Tags:        []string{"PII"},
		},
	}
	got, err := GetCustomLabels(labelsYamlFname)
	require.NoError(t, err)
	require.Len(t, got, len(want))
	for i := range got {
		require.Equal(t, want[i].Name, got[i].Name)
		require.Equal(t, want[i].Description, got[i].Description)
		require.ElementsMatch(t, want[i].Tags, got[i].Tags)
		// We don't care about the actual rule content, just that it was loaded.
		require.NotNil(t, got[i].ClassificationRule)
	}
}

// TestGetCustomLabels_RelativeRulePath_DifferentDir verifies that
// GetCustomLabels loads a label whose rule is given as a relative path that
// points outside the directory of the labels YAML file. The relative path is
// expressed relative to the labels YAML file's directory.
func TestGetCustomLabels_RelativeRulePath_DifferentDir(t *testing.T) {
	labelsDir := t.TempDir()
	ruleDir := t.TempDir()
	relRuleDir, err := filepath.Rel(labelsDir, ruleDir)
	require.NoError(t, err)

	// Create the labels YAML file. The keys nested under ADDRESS must be
	// indented (with spaces) so the document parses as a single label rather
	// than as four unrelated top-level keys.
	labelsYamlFname := filepath.Join(labelsDir, "labels.yaml")
	labelsYaml := fmt.Sprintf(
		`ADDRESS:
  description: Address
  rule: %s
  tags:
    - PII`, filepath.Join(relRuleDir, "address.rego"),
	)
	err = os.WriteFile(labelsYamlFname, []byte(labelsYaml), 0o644)
	require.NoError(t, err)

	// Create the rule rego file in the separate rule directory.
	ruleFname := filepath.Join(ruleDir, "address.rego")
	err = os.WriteFile(ruleFname, []byte("package foo"), 0o644)
	require.NoError(t, err)

	want := []Label{
		{
			Name:        "ADDRESS",
			Description: "Address",
			Tags:        []string{"PII"},
		},
	}
	got, err := GetCustomLabels(labelsYamlFname)
	require.NoError(t, err)
	require.Len(t, got, len(want))
	for i := range got {
		require.Equal(t, want[i].Name, got[i].Name)
		require.Equal(t, want[i].Description, got[i].Description)
		require.ElementsMatch(t, want[i].Tags, got[i].Tags)
		// We don't care about the actual rule content, just that it was loaded.
		require.NotNil(t, got[i].ClassificationRule)
	}
}

// TestGetCustomLabels_AbsoluteRulePath verifies that GetCustomLabels loads a
// label whose rule is given as an absolute path to a rule file located in a
// directory unrelated to the labels YAML file.
func TestGetCustomLabels_AbsoluteRulePath(t *testing.T) {
	labelsDir := t.TempDir()
	ruleDir := t.TempDir()
	ruleFname := filepath.Join(ruleDir, "address.rego")

	// Create the labels YAML file. The keys nested under ADDRESS must be
	// indented (with spaces) so the document parses as a single label rather
	// than as four unrelated top-level keys.
	labelsYamlFname := filepath.Join(labelsDir, "labels.yaml")
	labelsYaml := fmt.Sprintf(
		`ADDRESS:
  description: Address
  rule: %s
  tags:
    - PII`, ruleFname,
	)
	err := os.WriteFile(labelsYamlFname, []byte(labelsYaml), 0o644)
	require.NoError(t, err)

	// Create the rule rego file at the path referenced by the labels YAML.
	err = os.WriteFile(ruleFname, []byte("package foo"), 0o644)
	require.NoError(t, err)

	want := []Label{
		{
			Name:        "ADDRESS",
			Description: "Address",
			Tags:        []string{"PII"},
		},
	}
	got, err := GetCustomLabels(labelsYamlFname)
	require.NoError(t, err)
	require.Len(t, got, len(want))
	for i := range got {
		require.Equal(t, want[i].Name, got[i].Name)
		require.Equal(t, want[i].Description, got[i].Description)
		require.ElementsMatch(t, want[i].Tags, got[i].Tags)
		// We don't care about the actual rule content, just that it was loaded.
		require.NotNil(t, got[i].ClassificationRule)
	}
}
5 changes: 4 additions & 1 deletion sql/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,16 @@ func NewScanner(ctx context.Context, cfg ScannerConfig) (*Scanner, error) {
lbls, err = classification.GetCustomLabels(cfg.LabelsYamlFilename)
}
if err != nil {
errMsg := fmt.Sprintf("error(s) loading data labels")
errMsg := "error(s) loading data labels"
// This error means that some labels weren't loaded due to having
// invalid classification rules. We only log a warning in this case,
// since we still want to proceed with the labels that were
// successfully loaded.
var errs classification.InvalidLabelsError
if errors.As(err, &errs) {
if len(lbls) == 0 {
return nil, fmt.Errorf("%s; no labels were loaded: %w", errMsg, err)
}
log.WithError(errs).Warnf("%s: some labels were not loaded", errMsg)
} else {
return nil, fmt.Errorf("%s: %w", errMsg, err)
Expand Down

0 comments on commit c4628c8

Please sign in to comment.