diff --git a/internal/executer/executer.go b/internal/executer/executer.go index 6c05f1d..97521bf 100644 --- a/internal/executer/executer.go +++ b/internal/executer/executer.go @@ -51,9 +51,14 @@ func New(name string) Executer { // Execute operates on an executable binary and supports context. func (c *executable) Execute(ctx context.Context, input io.Reader, action string) ([]byte, error) { - traceHook := trace.GetTraceHookFromContext(ctx) - traceHook(c.name, action) + trace := trace.GetTraceHookFromContext(ctx) + if trace != nil { + trace.ExecuteStart(c.name, action) + } cmd := exec.CommandContext(ctx, c.name, action) + if trace != nil { + trace.ExecuteStart(c.name, action) + } cmd.Stdin = input cmd.Stderr = os.Stderr output, err := cmd.Output() diff --git a/native_store_test.go b/native_store_test.go index 55d7172..d861b57 100644 --- a/native_store_test.go +++ b/native_store_test.go @@ -31,7 +31,7 @@ import ( const ( basicAuthHost = "localhost:2333" - bearerAuthHost = "localhost:6666" + bearerAuthHost = "localhost:666" exeErrorHost = "localhost:500/exeError" jsonErrorHost = "localhost:500/jsonError" noCredentialsHost = "localhost:404" @@ -75,7 +75,10 @@ func (e *testExecuter) Execute(ctx context.Context, input io.Reader, action stri return []byte("credentials not found"), errCredentialsNotFound case traceHost: traceHook := trace.GetTraceHookFromContext(ctx) - traceHook("testExecuter", "get") + if traceHook != nil { + traceHook.ExecuteStart("testExecuter", "get") + traceHook.ExecuteDone("testExecuter", "get") + } return []byte(`{"Username": "test_username", "Secret": "test_password"}`), nil default: return []byte("program failed"), errCommandExited @@ -91,7 +94,10 @@ func (e *testExecuter) Execute(ctx context.Context, input io.Reader, action stri return nil, nil case traceHost: traceHook := trace.GetTraceHookFromContext(ctx) - traceHook("testExecuter", "store") + if traceHook != nil { + traceHook.ExecuteStart("testExecuter", "store") + traceHook.ExecuteDone("testExecuter", "store") + } return nil, nil default: return []byte("program failed"), errCommandExited @@ -102,7 +108,10 @@ func (e *testExecuter) Execute(ctx context.Context, input io.Reader, action stri return nil, nil case traceHost: traceHook := trace.GetTraceHookFromContext(ctx) - traceHook("testExecuter", "erase") + if traceHook != nil { + traceHook.ExecuteStart("testExecuter", "erase") + traceHook.ExecuteDone("testExecuter", "erase") + } return nil, nil default: return []byte("program failed"), errCommandExited @@ -208,17 +217,25 @@ func TestNativeStore_trace(t *testing.T) { } // create a trace hook that writes to buffer buffer := bytes.Buffer{} - traceHook := func(executableName string, action string) { - buffer.WriteString(fmt.Sprintf("test trace, running executable %s with action %s", executableName, action)) + traceHook := &trace.ExecutableTrace{ + ExecuteStart: func(executableName string, action string) { + buffer.WriteString(fmt.Sprintf("test trace, start the execution of executable %s with action %s ", executableName, action)) + }, + ExecuteDone: func(executableName string, action string) { + buffer.WriteString(fmt.Sprintf("test trace, completed the execution of executable %s with action %s", executableName, action)) + }, + } + ctx, err := trace.WithTraceHook(context.Background(), traceHook) + if err != nil { + t.Errorf("empty trace: %v", err) } - ctx := trace.NewContextWithTraceHook(context.Background(), traceHook) // Test ns.Put trace - err := ns.Put(ctx, traceHost, auth.Credential{Username: testUsername, Password: testPassword}) + err = ns.Put(ctx, traceHost, auth.Credential{Username: testUsername, Password: testPassword}) if err != nil { t.Fatalf("trace test ns.Put fails: %v", err) } bufferContent := buffer.String() - if bufferContent != "test trace, running executable testExecuter with action store" { + if bufferContent != "test trace, start the execution of executable testExecuter with action store test trace, completed the execution of executable testExecuter with action store" { t.Fatalf("incorrect buffer content: %s", bufferContent) } buffer.Reset() @@ -228,7 +245,7 @@ func TestNativeStore_trace(t *testing.T) { t.Fatalf("trace test ns.Get fails: %v", err) } bufferContent = buffer.String() - if bufferContent != "test trace, running executable testExecuter with action get" { + if bufferContent != "test trace, start the execution of executable testExecuter with action get test trace, completed the execution of executable testExecuter with action get" { t.Fatalf("incorrect buffer content: %s", bufferContent) } buffer.Reset() @@ -238,7 +255,35 @@ func TestNativeStore_trace(t *testing.T) { t.Fatalf("trace test ns.Delete fails: %v", err) } bufferContent = buffer.String() - if bufferContent != "test trace, running executable testExecuter with action erase" { + if bufferContent != "test trace, start the execution of executable testExecuter with action erase test trace, completed the execution of executable testExecuter with action erase" { t.Fatalf("incorrect buffer content: %s", bufferContent) } } + +// This test ensures that a nil trace will not cause an error. +func TestNativeStore_noTrace(t *testing.T) { + ns := &nativeStore{ + &testExecuter{}, + } + // Put + err := ns.Put(context.Background(), traceHost, auth.Credential{Username: testUsername, Password: testPassword}) + if err != nil { + t.Fatalf("basic auth test ns.Put fails: %v", err) + } + // Get + cred, err := ns.Get(context.Background(), traceHost) + if err != nil { + t.Fatalf("basic auth test ns.Get fails: %v", err) + } + if cred.Username != testUsername { + t.Fatal("incorrect username") + } + if cred.Password != testPassword { + t.Fatal("incorrect password") + } + // Delete + err = ns.Delete(context.Background(), traceHost) + if err != nil { + t.Fatalf("basic auth test ns.Delete fails: %v", err) + } +} diff --git a/trace/trace.go b/trace/trace.go index a8b13dd..ccc3e68 100644 --- a/trace/trace.go +++ b/trace/trace.go @@ -1,24 +1,51 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package trace -import "context" +import ( + "context" + "errors" +) -// executableTraceHookKey is a value key used to retrieve the trace hook function from Context. +// executableTraceHookKey is a value key used to retrieve the ExecutableTrace +// from Context. type executableTraceHookKey struct{} -// GetTraceHookFromContext returns the trace hook function associated with the -// provided context. If none, it returns nil. -func GetTraceHookFromContext(ctx context.Context) func(executableName string, action string) { - trace, _ := ctx.Value(executableTraceHookKey{}).(func(string, string)) +// ExecutableTrace is a set of hooks used to trace the execution of binary +// executables. +type ExecutableTrace struct { + // ExecuteStart is called before the execution of an executable. + ExecuteStart func(executableName string, action string) + // ExecuteEnd is called after the execution of an executable completes. + ExecuteDone func(executableName string, action string) +} + +// GetTraceHookFromContext returns the ExecutableTrace associated with the context. +// If none, it returns nil. +func GetTraceHookFromContext(ctx context.Context) *ExecutableTrace { + trace, _ := ctx.Value(executableTraceHookKey{}).(*ExecutableTrace) return trace } -// NewContextWithTraceHook takes a Context and a trace hook function of type -// func(executableName string, action string), and returns a new Context with -// the hook added as a Value. -func NewContextWithTraceHook(ctx context.Context, hook func(executableName string, action string)) context.Context { - if hook == nil { - panic("nil trace") +// WithTraceHook takes a Context and an ExecutableTrace, and returns a Context with +// the ExecutableTrace added as a Value. +func WithTraceHook(ctx context.Context, trace *ExecutableTrace) (context.Context, error) { + if trace == nil { + return nil, errors.New("empty ExecutableTrace") } - ctx = context.WithValue(ctx, executableTraceHookKey{}, hook) - return ctx + ctx = context.WithValue(ctx, executableTraceHookKey{}, trace) + return ctx, nil }