Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 27 additions & 8 deletions core/framework/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,21 +67,39 @@ func List() ([]Info, error) {
}

func Load(idOrFile string) (*Framework, error) {
if info, err := os.Stat(idOrFile); err == nil {
if isLikelyPath(idOrFile) {
info, err := os.Stat(idOrFile)
if err != nil {
return nil, fmt.Errorf("load framework file %s: %w", idOrFile, err)
}
if info.IsDir() {
if isLikelyPath(idOrFile) {
return nil, fmt.Errorf("load framework file %s: path is a directory", idOrFile)
}
} else {
return LoadFile(idOrFile)
return nil, fmt.Errorf("load framework file %s: path is a directory", idOrFile)
}
} else if isLikelyPath(idOrFile) {
return LoadFile(idOrFile)
}

embedded, embeddedErr := loadEmbedded(idOrFile)
if embeddedErr == nil {
return embedded, nil
}

info, err := os.Stat(idOrFile)
if err == nil {
if info.IsDir() {
return nil, fmt.Errorf("load framework file %s: path is a directory", idOrFile)
}
return LoadFile(idOrFile)
}
if !os.IsNotExist(err) {
return nil, fmt.Errorf("load framework file %s: %w", idOrFile, err)
}
return nil, embeddedErr
}

func loadEmbedded(idOrFile string) (*Framework, error) {
name := idOrFile
if !strings.HasSuffix(name, ".yaml") {
name = name + ".yaml"
name += ".yaml"
}
raw, err := frameworkFS.ReadFile(name)
if err != nil {
Expand All @@ -91,6 +109,7 @@ func Load(idOrFile string) (*Framework, error) {
}

func LoadFile(path string) (*Framework, error) {
// #nosec G304 -- path is intentionally caller-provided for runtime custom framework loading.
raw, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("load framework file %s: %w", path, err)
Expand Down
43 changes: 43 additions & 0 deletions core/framework/framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,49 @@ func TestLoadMissingFilesystemPath(t *testing.T) {
require.ErrorContains(t, err, "load framework file")
}

func TestLoadPrefersEmbeddedOverLocalCollision(t *testing.T) {
dir := t.TempDir()
t.Chdir(dir)
require.NoError(t, os.WriteFile("eu-ai-act.yaml", []byte("not: valid: yaml: ["), 0o644))

f, err := Load("eu-ai-act.yaml")
require.NoError(t, err)
require.Equal(t, "eu-ai-act", f.Framework.ID)
require.Equal(t, "2024-final", f.Framework.Version)

list, err := List()
require.NoError(t, err)
var found bool
for _, info := range list {
if info.ID == "eu-ai-act" {
found = true
require.Equal(t, 3, info.ControlCount)
}
}
require.True(t, found)
}

func TestLoadFilenameFallbackForCustomFramework(t *testing.T) {
dir := t.TempDir()
t.Chdir(dir)
require.NoError(t, os.WriteFile("custom-framework.yaml", []byte(`
framework:
id: custom-framework
version: "1"
title: Custom Framework
controls:
- id: custom-control
title: Custom Control
required_record_types: [decision]
required_fields: [record_id]
minimum_frequency: continuous
`), 0o644))

f, err := Load("custom-framework.yaml")
require.NoError(t, err)
require.Equal(t, "custom-framework", f.Framework.ID)
}

func TestValidateControls(t *testing.T) {
valid := []Control{
{
Expand Down
Loading