diff --git a/internal/examples/supervisor/supervisor/commander/commander.go b/internal/examples/supervisor/supervisor/commander/commander.go index bd2f1f69..13fd7333 100644 --- a/internal/examples/supervisor/supervisor/commander/commander.go +++ b/internal/examples/supervisor/supervisor/commander/commander.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "os/exec" + "sync" "sync/atomic" "syscall" "time" @@ -23,7 +24,11 @@ type Commander struct { cmd *exec.Cmd doneCh chan struct{} waitCh chan struct{} - running int64 + running atomic.Bool + + // True when stopping is in progress. + isStoppingFlag bool + isStoppingMutex sync.RWMutex } func NewCommander(logger types.Logger, cfg *config.Agent, args ...string) (*Commander, error) { @@ -41,6 +46,10 @@ func NewCommander(logger types.Logger, cfg *config.Agent, args ...string) (*Comm // Start the Agent and begin watching the process. // Agent's stdout and stderr are written to a file. func (c *Commander) Start(ctx context.Context) error { + if c.IsStopping() { + return nil + } + c.logger.Debugf(ctx, "Starting agent %s", c.cfg.Executable) logFilePath := "agent.log" @@ -63,7 +72,7 @@ func (c *Commander) Start(ctx context.Context) error { } c.logger.Debugf(ctx, "Agent process started, PID=%d", c.cmd.Process.Pid) - atomic.StoreInt64(&c.running, 1) + c.running.Store(true) go c.watch() @@ -83,7 +92,7 @@ func (c *Commander) Restart(ctx context.Context) error { func (c *Commander) watch() { c.cmd.Wait() c.doneCh <- struct{}{} - atomic.StoreInt64(&c.running, 0) + c.running.Store(false) close(c.waitCh) } @@ -94,7 +103,7 @@ func (c *Commander) Done() <-chan struct{} { // Pid returns Agent process PID if it is started or 0 if it is not. func (c *Commander) Pid() int { - if c.cmd == nil || c.cmd.Process == nil { + if !c.IsRunning() { return 0 } return c.cmd.Process.Pid @@ -102,25 +111,28 @@ func (c *Commander) Pid() int { // ExitCode returns Agent process exit code if it exited or 0 if it is not. func (c *Commander) ExitCode() int { - if c.cmd == nil || c.cmd.ProcessState == nil { + if c.IsRunning() { return 0 } return c.cmd.ProcessState.ExitCode() } func (c *Commander) IsRunning() bool { - return atomic.LoadInt64(&c.running) != 0 + return c.running.Load() } // Stop the Agent process. Sends SIGTERM to the process and wait for up 10 seconds // and if the process does not finish kills it forcedly by sending SIGKILL. // Returns after the process is terminated. func (c *Commander) Stop(ctx context.Context) error { - if c.cmd == nil || c.cmd.Process == nil { + c.isStoppingMutex.Lock() + c.isStoppingFlag = true + c.isStoppingMutex.Unlock() + + if !c.IsRunning() { // Not started, nothing to do. return nil } - c.logger.Debugf(ctx, "Stopping agent process, PID=%v", c.cmd.Process.Pid) // Gracefully signal process to stop. @@ -159,10 +171,17 @@ func (c *Commander) Stop(ctx context.Context) error { // Wait for process to terminate <-c.waitCh - atomic.StoreInt64(&c.running, 0) + c.running.Store(false) // Let goroutine know process is finished. close(finished) return innerErr } + +// IsStopping returns true if Stop() was called. +func (c *Commander) IsStopping() bool { + c.isStoppingMutex.RLock() + defer c.isStoppingMutex.RUnlock() + return c.isStoppingFlag +} diff --git a/internal/examples/supervisor/supervisor/supervisor_test.go b/internal/examples/supervisor/supervisor/supervisor_test.go index 45687201..fbadeea7 100644 --- a/internal/examples/supervisor/supervisor/supervisor_test.go +++ b/internal/examples/supervisor/supervisor/supervisor_test.go @@ -4,7 +4,8 @@ import ( "fmt" "os" "testing" - + "time" + "github.com/stretchr/testify/assert" "github.com/open-telemetry/opamp-go/internal" @@ -62,3 +63,33 @@ agent: supervisor.Shutdown() } + +func TestShutdownRaceCondition(t *testing.T) { + tmpDir := changeCurrentDir(t) + os.WriteFile("supervisor.yaml", []byte(fmt.Sprintf(` +server: + endpoint: ws://127.0.0.1:4320/v1/opamp +agent: + executable: %s/dummy_agent.sh`, tmpDir)), 0644) + + os.WriteFile("dummy_agent.sh", []byte("#!/bin/sh\nsleep 9999\n"), 0755) + + startOpampServer(t) + + // There's no great way to ensure Shutdown gets called before Start. + // The DelayLogger ensures some delay before the goroutine gets started. + var supervisor *Supervisor + var err error + supervisor, err = NewSupervisor(&internal.DelayLogger{}) + supervisor.Shutdown() + supervisor.hasNewConfig <- struct{}{} + + assert.NoError(t, err) + + // The Shutdown method has been called before the runAgentProcess goroutine + // gets started and has a chance to load a new process. Make sure no PID + // has been launched. + assert.Never(t, func() bool { + return supervisor.commander.Pid() != 0 + }, 2*time.Second, 10*time.Millisecond) +} diff --git a/internal/noplogger.go b/internal/noplogger.go index a2b2ea27..9807e6e6 100644 --- a/internal/noplogger.go +++ b/internal/noplogger.go @@ -2,6 +2,7 @@ package internal import ( "context" + "time" "github.com/open-telemetry/opamp-go/client/types" ) @@ -12,3 +13,12 @@ type NopLogger struct{} func (l *NopLogger) Debugf(ctx context.Context, format string, v ...interface{}) {} func (l *NopLogger) Errorf(ctx context.Context, format string, v ...interface{}) {} + +type DelayLogger struct{} + +func (l *DelayLogger) Debugf(ctx context.Context, format string, v ...interface{}) { + time.Sleep(10 * time.Millisecond) +} +func (l *DelayLogger) Errorf(ctx context.Context, format string, v ...interface{}) { + time.Sleep(10 * time.Millisecond) +}