diff --git a/classification/label.go b/classification/label.go index 850cc50..138e37b 100644 --- a/classification/label.go +++ b/classification/label.go @@ -76,7 +76,7 @@ func NewLabel(name, description, classificationRule string, tags ...string) (Lab // error in reality, as the embedded labels should always be valid. If it does, // it indicates a problem with the embedded labels! func GetPredefinedLabels() ([]Label, error) { - return getLabels("labels/labels.yaml", predefinedLabelsFs) + return getLabels("labels/labels.yaml", true) } // GetCustomLabels loads and returns the labels and their classification rules @@ -91,19 +91,36 @@ func GetCustomLabels(labelsYamlFname string) ([]Label, error) { if err != nil { return nil, fmt.Errorf("error getting absolute path for labels yaml file %s: %w", labelsYamlFname, err) } - labelFs := os.DirFS(filepath.Dir(path)) - return getLabels(filepath.Base(path), labelFs.(fs.ReadFileFS)) + return getLabels(path, false) } -func getLabels(fname string, labelsFs fs.ReadFileFS) ([]Label, error) { +// getLabels reads the labels YAML file from the given path and returns the +// labels and their classification rules. If predefined is true, the labels are +// read from the embedded FS, otherwise they are read from the file system. If +// there is an error reading or unmarshalling the labels file, it is returned. +// If there are errors reading or parsing the classification rules for labels, the +// errors are aggregated into an InvalidLabelsError and returned, along with the +// labels that were successfully read. +func getLabels(path string, predefined bool) ([]Label, error) { + var ( + labelsFs fs.ReadFileFS + labelsFname string + ) + if predefined { + labelsFs = predefinedLabelsFs + labelsFname = path + } else { + labelsFs = os.DirFS(filepath.Dir(path)).(fs.ReadFileFS) + labelsFname = filepath.Base(path) + } // Read and parse the labels yaml file. 
- yamlBytes, err := labelsFs.ReadFile(fname) + yamlBytes, err := labelsFs.ReadFile(labelsFname) if err != nil { - return nil, fmt.Errorf("error reading label yaml file %s", fname) + return nil, fmt.Errorf("error reading label yaml file %s", path) } type yamlLabel struct { - Label - Rule string `yaml:"rule"` + Label `yaml:",inline"` + Rule string `yaml:"rule"` } yamlLabels := make(map[string]yamlLabel) if err := yaml.Unmarshal(yamlBytes, &yamlLabels); err != nil { @@ -113,16 +130,37 @@ func getLabels(fname string, labelsFs fs.ReadFileFS) ([]Label, error) { // Read each label's classification rule. var errs []error for name, lbl := range yamlLabels { - // We assume that the rule will be defined with a relative path in the - // labels yaml file. If the rule is absolute, we create a new fs for - // the rule's directory. This implies that embedded labels will not work - // if the rule is absolute, but that should not be a problem because - // they should always be relative. - ruleFname := filepath.Join(filepath.Dir(fname), lbl.Rule) - ruleFs := labelsFs - if filepath.IsAbs(ruleFname) { - ruleFs = os.DirFS(filepath.Dir(ruleFname)).(fs.ReadFileFS) - ruleFname = filepath.Base(ruleFname) + // The rule file for this label is either an absolute path, a relative + // path, or a predefined rule. We need to determine the fs to use to + // read the rule file. + var ( + ruleFs fs.ReadFileFS + ruleFname string + ) + if predefined { + // We're dealing with the predefined labels, therefore the rule FS + // is the same embedded FS as the labels YAML file. However, we need to + // use the dir of the labels yaml file as the rule file root because + // this is the root of the embedded fs - it's a bit of a quirk with + // the embedded FS API. 
+ ruleFname = filepath.Join(filepath.Dir(path), lbl.Rule) + ruleFs = labelsFs + } else { + ruleFname = filepath.Base(lbl.Rule) + if filepath.IsAbs(lbl.Rule) { + // The rule has an absolute path, so we need to create a new fs + // for the rule's directory. + ruleFs = os.DirFS(filepath.Dir(lbl.Rule)).(fs.ReadFileFS) + } else { + // The rule has a relative path, which is relative to the labels + // YAML file (as opposed to the current directory). + ruleFs = os.DirFS( + filepath.Join( + filepath.Dir(path), + filepath.Dir(lbl.Rule), + ), + ).(fs.ReadFileFS) + } } rule, err := readLabelRule(ruleFname, ruleFs) if err != nil { diff --git a/classification/label_test.go b/classification/label_test.go index 2262f98..430060a 100644 --- a/classification/label_test.go +++ b/classification/label_test.go @@ -2,6 +2,9 @@ package classification import ( "context" + "fmt" + "os" + "path/filepath" "testing" "github.com/stretchr/testify/require" @@ -19,18 +22,31 @@ func TestGetPredefinedLabels_LabelsAreValid(t *testing.T) { // against the labels returned by GetPredefinedLabels. fname := "labels/labels.yaml" type yamlLabel struct { - Label - Rule string `yaml:"rule"` + Label `yaml:",inline"` + Rule string `yaml:"rule"` } yamlBytes, err := predefinedLabelsFs.ReadFile(fname) require.NoError(t, err) yamlLabels := make(map[string]yamlLabel) err = yaml.Unmarshal(yamlBytes, &yamlLabels) require.NoError(t, err) + for name, lbl := range yamlLabels { + lbl.Name = name + yamlLabels[name] = lbl + } got, err := GetPredefinedLabels() require.NoError(t, err) require.Len(t, got, len(yamlLabels)) + for i, lbl := range got { + want := yamlLabels[lbl.Name] + require.Equal(t, want.Name, got[i].Name) + require.Equal(t, want.Description, got[i].Description) + require.ElementsMatch(t, want.Tags, got[i].Tags) + // We don't care about the actual rule content here, just that it was + // loaded. We'll validate the content below. 
+ require.NotNil(t, got[i].ClassificationRule) + } // Validate the classification rules for each label by doing dummy // classification. We don't expect any results for an empty input - we @@ -38,7 +54,139 @@ func TestGetPredefinedLabels_LabelsAreValid(t *testing.T) { // there were no errors during the classification process. ctx := context.Background() classifier, err := NewLabelClassifier(ctx, got...) + require.NoError(t, err) res, err := classifier.Classify(ctx, map[string]any{}) require.NoError(t, err) require.Empty(t, res) } + +func TestGetCustomLabels_RelativeRulePath_SameDir(t *testing.T) { + labelsDir := t.TempDir() + + // Create the labels YAML file. + labelsYamlFile, err := os.CreateTemp(labelsDir, "labels.yaml") + defer func() { _ = labelsYamlFile.Close() }() + require.NoError(t, err) + labelsYaml := `ADDRESS: + description: Address + rule: address.rego + tags: + - PII` + err = os.WriteFile(labelsYamlFile.Name(), []byte(labelsYaml), os.FileMode(0755)) + require.NoError(t, err) + + // Create the rule rego file. + ruleFile, err := os.Create(filepath.Join(labelsDir, "address.rego")) + defer func() { _ = ruleFile.Close() }() + require.NoError(t, err) + err = os.WriteFile(ruleFile.Name(), []byte("package foo"), os.FileMode(0755)) + require.NoError(t, err) + + want := []Label{ + { + Name: "ADDRESS", + Description: "Address", + Tags: []string{"PII"}, + }, + } + got, err := GetCustomLabels(labelsYamlFile.Name()) + require.NoError(t, err) + require.Len(t, got, len(want)) + for i := range got { + require.Equal(t, want[i].Name, got[i].Name) + require.Equal(t, want[i].Description, got[i].Description) + require.ElementsMatch(t, want[i].Tags, got[i].Tags) + // We don't care about the actual rule content, just that it was loaded. 
+ require.NotNil(t, got[i].ClassificationRule) + } +} + +func TestGetCustomLabels_RelativeRulePath_DifferentDir(t *testing.T) { + labelsDir := t.TempDir() + ruleDir := t.TempDir() + relRulPath, err := filepath.Rel(labelsDir, ruleDir) + require.NoError(t, err) + + // Create the labels YAML file. + labelsYamlFile, err := os.CreateTemp(labelsDir, "labels.yaml") + defer func() { _ = labelsYamlFile.Close() }() + require.NoError(t, err) + labelsYaml := fmt.Sprintf( + `ADDRESS: + description: Address + rule: %s/address.rego + tags: + - PII`, relRulPath, + ) + err = os.WriteFile(labelsYamlFile.Name(), []byte(labelsYaml), os.FileMode(0755)) + require.NoError(t, err) + + // Create the rule rego file. + ruleFile, err := os.Create(filepath.Join(ruleDir, "address.rego")) + defer func() { _ = ruleFile.Close() }() + require.NoError(t, err) + err = os.WriteFile(ruleFile.Name(), []byte("package foo"), os.FileMode(0755)) + require.NoError(t, err) + + want := []Label{ + { + Name: "ADDRESS", + Description: "Address", + Tags: []string{"PII"}, + }, + } + got, err := GetCustomLabels(labelsYamlFile.Name()) + require.NoError(t, err) + require.Len(t, got, len(want)) + for i := range got { + require.Equal(t, want[i].Name, got[i].Name) + require.Equal(t, want[i].Description, got[i].Description) + require.ElementsMatch(t, want[i].Tags, got[i].Tags) + // We don't care about the actual rule content, just that it was loaded. + require.NotNil(t, got[i].ClassificationRule) + } +} + +func TestGetCustomLabels_AbsoluteRulePath(t *testing.T) { + labelsDir := t.TempDir() + ruleDir := t.TempDir() + + // Create the labels YAML file. 
+ labelsYamlFile, err := os.CreateTemp(labelsDir, "labels.yaml") + defer func() { _ = labelsYamlFile.Close() }() + require.NoError(t, err) + labelsYaml := fmt.Sprintf( + `ADDRESS: + description: Address + rule: %s/address.rego + tags: + - PII`, ruleDir, + ) + err = os.WriteFile(labelsYamlFile.Name(), []byte(labelsYaml), os.FileMode(0755)) + require.NoError(t, err) + + // Create the rule rego file. + ruleFile, err := os.Create(filepath.Join(ruleDir, "address.rego")) + defer func() { _ = ruleFile.Close() }() + require.NoError(t, err) + err = os.WriteFile(ruleFile.Name(), []byte("package foo"), os.FileMode(0755)) + require.NoError(t, err) + + want := []Label{ + { + Name: "ADDRESS", + Description: "Address", + Tags: []string{"PII"}, + }, + } + got, err := GetCustomLabels(labelsYamlFile.Name()) + require.NoError(t, err) + require.Len(t, got, len(want)) + for i := range got { + require.Equal(t, want[i].Name, got[i].Name) + require.Equal(t, want[i].Description, got[i].Description) + require.ElementsMatch(t, want[i].Tags, got[i].Tags) + // We don't care about the actual rule content, just that it was loaded. + require.NotNil(t, got[i].ClassificationRule) + } +} diff --git a/sql/scanner.go b/sql/scanner.go index 7c98d45..4201a8a 100644 --- a/sql/scanner.go +++ b/sql/scanner.go @@ -62,13 +62,16 @@ func NewScanner(ctx context.Context, cfg ScannerConfig) (*Scanner, error) { lbls, err = classification.GetCustomLabels(cfg.LabelsYamlFilename) } if err != nil { - errMsg := fmt.Sprintf("error(s) loading data labels") + errMsg := "error(s) loading data labels" // This error means that some labels weren't loaded due to having // invalid classification rules. We only log a warning in this case, // since we still want to proceed with the labels that were // successfully loaded. 
var errs classification.InvalidLabelsError if errors.As(err, &errs) { + if len(lbls) == 0 { + return nil, fmt.Errorf("%s; no labels were loaded: %w", errMsg, err) + } log.WithError(errs).Warnf("%s: some labels were not loaded", errMsg) } else { return nil, fmt.Errorf("%s: %w", errMsg, err)