diff --git a/core/framework/framework.go b/core/framework/framework.go index 8444320..a8849c5 100644 --- a/core/framework/framework.go +++ b/core/framework/framework.go @@ -67,21 +67,39 @@ func List() ([]Info, error) { } func Load(idOrFile string) (*Framework, error) { - if info, err := os.Stat(idOrFile); err == nil { + if isLikelyPath(idOrFile) { + info, err := os.Stat(idOrFile) + if err != nil { + return nil, fmt.Errorf("load framework file %s: %w", idOrFile, err) + } if info.IsDir() { - if isLikelyPath(idOrFile) { - return nil, fmt.Errorf("load framework file %s: path is a directory", idOrFile) - } - } else { - return LoadFile(idOrFile) + return nil, fmt.Errorf("load framework file %s: path is a directory", idOrFile) } - } else if isLikelyPath(idOrFile) { + return LoadFile(idOrFile) + } + + embedded, embeddedErr := loadEmbedded(idOrFile) + if embeddedErr == nil { + return embedded, nil + } + + info, err := os.Stat(idOrFile) + if err == nil { + if info.IsDir() { + return nil, fmt.Errorf("load framework file %s: path is a directory", idOrFile) + } + return LoadFile(idOrFile) + } + if !os.IsNotExist(err) { return nil, fmt.Errorf("load framework file %s: %w", idOrFile, err) } + return nil, embeddedErr +} +func loadEmbedded(idOrFile string) (*Framework, error) { name := idOrFile if !strings.HasSuffix(name, ".yaml") { - name = name + ".yaml" + name += ".yaml" } raw, err := frameworkFS.ReadFile(name) if err != nil { @@ -91,6 +109,7 @@ func Load(idOrFile string) (*Framework, error) { } func LoadFile(path string) (*Framework, error) { + // #nosec G304 -- path is intentionally caller-provided for runtime custom framework loading. raw, err := os.ReadFile(path) if err != nil { return nil, fmt.Errorf("load framework file %s: %w", path, err) diff --git a/core/framework/framework_test.go b/core/framework/framework_test.go index 04e755a..dc271d5 100644 --- a/core/framework/framework_test.go +++ b/core/framework/framework_test.go @@ -57,6 +57,49 @@ func TestLoadMissingFilesystemPath(t *testing.T) { require.ErrorContains(t, err, "load framework file") } +func TestLoadPrefersEmbeddedOverLocalCollision(t *testing.T) { + dir := t.TempDir() + t.Chdir(dir) + require.NoError(t, os.WriteFile("eu-ai-act.yaml", []byte("not: valid: yaml: ["), 0o644)) + + f, err := Load("eu-ai-act.yaml") + require.NoError(t, err) + require.Equal(t, "eu-ai-act", f.Framework.ID) + require.Equal(t, "2024-final", f.Framework.Version) + + list, err := List() + require.NoError(t, err) + var found bool + for _, info := range list { + if info.ID == "eu-ai-act" { + found = true + require.Equal(t, 3, info.ControlCount) + } + } + require.True(t, found) +} + +func TestLoadFilenameFallbackForCustomFramework(t *testing.T) { + dir := t.TempDir() + t.Chdir(dir) + require.NoError(t, os.WriteFile("custom-framework.yaml", []byte(` +framework: + id: custom-framework + version: "1" + title: Custom Framework +controls: + - id: custom-control + title: Custom Control + required_record_types: [decision] + required_fields: [record_id] + minimum_frequency: continuous +`), 0o644)) + + f, err := Load("custom-framework.yaml") + require.NoError(t, err) + require.Equal(t, "custom-framework", f.Framework.ID) +} + func TestValidateControls(t *testing.T) { valid := []Control{ {