From 41255c2f89ba68a19d463dc821c4d1aa286ecdf8 Mon Sep 17 00:00:00 2001 From: Robert Terhaar Date: Sat, 7 Feb 2026 14:24:54 -0500 Subject: [PATCH 1/3] make Stop() blocking, simplify FSM broadcast --- internal/finitestate/machine.go | 79 ++++----- internal/finitestate/machine_test.go | 8 +- runnables/composite/runner.go | 25 +-- runnables/composite/runner_test.go | 165 +++++++++++++------ runnables/httpcluster/runner.go | 23 ++- runnables/httpserver/reload_test.go | 3 - runnables/httpserver/runner.go | 25 ++- runnables/httpserver/runner_shutdown_test.go | 44 +++-- runnables/httpserver/runner_test.go | 3 - runnables/httpserver/state_mocked_test.go | 3 - runnables/httpserver/state_test.go | 26 +-- supervisor/lifecycle/startstop.go | 54 ++++++ supervisor/lifecycle/startstop_test.go | 147 +++++++++++++++++ 13 files changed, 414 insertions(+), 191 deletions(-) create mode 100644 supervisor/lifecycle/startstop.go create mode 100644 supervisor/lifecycle/startstop_test.go diff --git a/internal/finitestate/machine.go b/internal/finitestate/machine.go index 7ea477c..ce2bbd8 100644 --- a/internal/finitestate/machine.go +++ b/internal/finitestate/machine.go @@ -7,7 +7,6 @@ import ( "github.com/robbyt/go-fsm/v2" "github.com/robbyt/go-fsm/v2/hooks" - "github.com/robbyt/go-fsm/v2/hooks/broadcast" "github.com/robbyt/go-fsm/v2/transitions" ) @@ -25,80 +24,68 @@ const ( // TypicalTransitions is a set of standard transitions for a finite state machine. var TypicalTransitions = transitions.Typical -// Machine is a wrapper around go-fsm v2 that provides the v1 API compatibility. -// It manages both the FSM and broadcast functionality. +// Machine wraps go-fsm v2 to provide a simplified API with broadcast support. type Machine struct { *fsm.Machine - broadcastManager *broadcast.Manager } // GetStateChan returns a channel that emits the state whenever it changes. -// The channel is closed when the provided context is canceled. -// For v1 API compatibility, the current state is sent immediately to the channel. -// A 5-second broadcast timeout is used to prevent slow consumers from blocking state updates. +// The current state is sent immediately. The channel is closed when the +// provided context is canceled. func (s *Machine) GetStateChan(ctx context.Context) <-chan string { - return s.getStateChanInternal(ctx, broadcast.WithTimeout(5*time.Second)) -} - -// getStateChanInternal is a helper that creates a channel and sends the current state to it. -// This maintains v1 API compatibility where GetStateChan immediately sends the current state. -func (s *Machine) getStateChanInternal(ctx context.Context, opts ...broadcast.Option) <-chan string { - wrappedCh := make(chan string, 1) - - userCh, err := s.broadcastManager.GetStateChan(ctx, opts...) - if err != nil { - close(wrappedCh) - return wrappedCh + sourceCh := make(chan string, 1) + if err := s.Machine.GetStateChan(ctx, sourceCh); err != nil { + ch := make(chan string) + close(ch) + return ch } - currentState := s.GetState() - wrappedCh <- currentState + // go-fsm sends the current state synchronously into sourceCh before + // returning. Forward it into outCh now so callers can do non-blocking reads. + outCh := make(chan string, 1) + outCh <- <-sourceCh go func() { - defer close(wrappedCh) - for state := range userCh { - wrappedCh <- state + defer close(outCh) + for { + select { + case state := <-sourceCh: + outCh <- state + case <-ctx.Done(): + // Drain any state already buffered in sourceCh before closing. + select { + case state := <-sourceCh: + outCh <- state + default: + } + return + } } }() - return wrappedCh + return outCh } -// New creates a new finite state machine with the specified logger using "standard" state transitions. -// This function provides compatibility with the v1 API while using v2 under the hood. -func New(handler slog.Handler) (*Machine, error) { +// New creates a new finite state machine with the specified logger and transitions. +func New(handler slog.Handler, t *transitions.Config) (*Machine, error) { registry, err := hooks.NewRegistry( hooks.WithLogHandler(handler), - hooks.WithTransitions(TypicalTransitions), + hooks.WithTransitions(t), ) if err != nil { return nil, err } - broadcastManager := broadcast.NewManager(handler) - - err = registry.RegisterPostTransitionHook(hooks.PostTransitionHookConfig{ - Name: "broadcast", - From: []string{"*"}, - To: []string{"*"}, - Action: broadcastManager.BroadcastHook, - }) - if err != nil { - return nil, err - } - f, err := fsm.New( StatusNew, - TypicalTransitions, + t, fsm.WithLogHandler(handler), fsm.WithCallbackRegistry(registry), + fsm.WithBroadcastTimeout(5*time.Second), ) if err != nil { return nil, err } - return &Machine{ - Machine: f, - broadcastManager: broadcastManager, - }, nil + return &Machine{Machine: f}, nil } diff --git a/internal/finitestate/machine_test.go b/internal/finitestate/machine_test.go index 2d98a95..6a95311 100644 --- a/internal/finitestate/machine_test.go +++ b/internal/finitestate/machine_test.go @@ -16,7 +16,7 @@ func TestNew(t *testing.T) { t.Run("creates new machine with correct initial state", func(t *testing.T) { handler := slog.NewTextHandler(os.Stdout, nil) - machine, err := New(handler) + machine, err := New(handler, TypicalTransitions) require.NoError(t, err) require.NotNil(t, machine) @@ -26,7 +26,7 @@ func TestNew(t *testing.T) { t.Run("uses provided handler", func(t *testing.T) { // Create a test handler handler := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}) - machine, err := New(handler) + machine, err := New(handler, TypicalTransitions) require.NoError(t, err) require.NotNil(t, machine) @@ -38,7 +38,7 @@ func TestMachineInterface(t *testing.T) { setup := func() *Machine { handler := slog.NewTextHandler(os.Stdout, nil) - m, err := New(handler) + m, err := New(handler, TypicalTransitions) require.NoError(t, err) return m } @@ -203,7 +203,7 @@ func TestGetStateChanWithTimeout(t *testing.T) { setup := func() *Machine { handler := slog.NewTextHandler(os.Stdout, nil) - m, err := New(handler) + m, err := New(handler, TypicalTransitions) require.NoError(t, err) return m } diff --git a/runnables/composite/runner.go b/runnables/composite/runner.go index 4ca1c87..7facc35 100644 --- a/runnables/composite/runner.go +++ b/runnables/composite/runner.go @@ -9,6 +9,7 @@ import ( "sync/atomic" "github.com/robbyt/go-supervisor/internal/finitestate" + "github.com/robbyt/go-supervisor/supervisor/lifecycle" ) // ConfigCallback is the function type signature for the callback used to load initial config, and new config during Reload() @@ -26,6 +27,7 @@ type fsm interface { // as a single unit. It satisfies the Runnable, Reloadable, and Stateable interfaces. type Runner[T runnable] struct { fsm fsm + lc *lifecycle.StartStop configMu sync.Mutex // Only used for getConfig() currentConfig atomic.Pointer[Config[T]] configCallback ConfigCallback[T] @@ -53,6 +55,7 @@ func NewRunner[T runnable]( ) (*Runner[T], error) { logger := slog.Default().WithGroup("composite.Runner") r := &Runner[T]{ + lc: lifecycle.New(), currentConfig: atomic.Pointer[Config[T]]{}, configCallback: configCallback, serverErrors: make(chan error, 1), @@ -75,6 +78,7 @@ func NewRunner[T runnable]( // Create FSM after the optional logger has been configured fsm, err := finitestate.New( r.logger.WithGroup("fsm").Handler(), + finitestate.TypicalTransitions, ) if err != nil { return nil, fmt.Errorf("unable to create fsm: %w", err) @@ -99,7 +103,9 @@ func (r *Runner[T]) Run(ctx context.Context) error { runCtx, runCancel := context.WithCancel(ctx) defer runCancel() - // store the Run context and cancel function in the runner so that Reload() and Stop() can use them later + done := r.lc.Started() + defer done() + r.runnablesMu.Lock() r.ctx = runCtx r.cancel = runCancel @@ -122,10 +128,13 @@ func (r *Runner[T]) Run(ctx context.Context) error { return fmt.Errorf("failed to transition to Running state: %w", err) } - // Wait for context cancellation or errors + // Wait for context cancellation, stop signal, or errors select { case <-runCtx.Done(): r.logger.Debug("Local context canceled") + case <-r.lc.StopCh(): + r.logger.Debug("Stop() called") + runCancel() case err := <-r.serverErrors: r.setStateError() stopErr := r.stopAllRunnables() @@ -153,17 +162,9 @@ func (r *Runner[T]) Run(ctx context.Context) error { return nil } -// Stop will cancel the context, causing all child runnables to stop. +// Stop signals the runner to shut down and blocks until Run() completes. func (r *Runner[T]) Stop() { - r.runnablesMu.Lock() - cancel := r.cancel - r.runnablesMu.Unlock() - - if cancel == nil { - r.logger.Warn("Cancel function is nil, skipping Stop") - return - } - cancel() + r.lc.Stop() } // boot starts all child runnables in the order they're defined. diff --git a/runnables/composite/runner_test.go b/runnables/composite/runner_test.go index edb26a6..378509f 100644 --- a/runnables/composite/runner_test.go +++ b/runnables/composite/runner_test.go @@ -5,6 +5,7 @@ import ( "errors" "log/slog" "os" + "sync" "sync/atomic" "testing" "testing/synctest" @@ -427,76 +428,138 @@ func TestCompositeRunner_Run(t *testing.T) { func TestCompositeRunner_Stop(t *testing.T) { t.Parallel() - // Create reusable mock setup function - setupMocksAndConfig := func() (entries []RunnableEntry[*mocks.Runnable], configFunc func() (*Config[*mocks.Runnable], error)) { - mockRunnable1 := mocks.NewMockRunnable() - mockRunnable2 := mocks.NewMockRunnable() + t.Run("stop blocks until run completes", func(t *testing.T) { + t.Parallel() + synctest.Test(t, func(t *testing.T) { + mock1 := mocks.NewMockRunnable() + mock1.On("String").Return("r1") + mock1.On("Run", mock.Anything).Run(func(args mock.Arguments) { + <-args.Get(0).(context.Context).Done() + }).Return(nil) + mock1.On("Stop").Once() + + entries := []RunnableEntry[*mocks.Runnable]{ + {Runnable: mock1}, + } - entries = []RunnableEntry[*mocks.Runnable]{ - {Runnable: mockRunnable1, Config: nil}, - {Runnable: mockRunnable2, Config: nil}, - } + configCallback := func() (*Config[*mocks.Runnable], error) { + return NewConfig("test", entries) + } - // Create config callback - configFunc = func() (*Config[*mocks.Runnable], error) { - return NewConfig("test", entries) - } + runner, err := NewRunner(configCallback) + require.NoError(t, err) - return entries, configFunc - } + errCh := make(chan error, 1) + runReturned := atomic.Bool{} + go func() { + errCh <- runner.Run(t.Context()) + runReturned.Store(true) + }() + + time.Sleep(10 * time.Millisecond) + synctest.Wait() + assert.Equal(t, finitestate.StatusRunning, runner.GetState()) + + runner.Stop() + + assert.True(t, runReturned.Load(), "Run should have returned before Stop unblocked") + assert.Equal(t, finitestate.StatusStopped, runner.GetState()) + require.NoError(t, <-errCh) + + mock1.AssertExpectations(t) + }) + }) - t.Run("stop from running state", func(t *testing.T) { + t.Run("stop before run", func(t *testing.T) { t.Parallel() + synctest.Test(t, func(t *testing.T) { + mock1 := mocks.NewMockRunnable() + mock1.On("String").Return("r1") + mock1.On("Run", mock.Anything).Run(func(args mock.Arguments) { + <-args.Get(0).(context.Context).Done() + }).Return(nil) + mock1.On("Stop").Once() + + entries := []RunnableEntry[*mocks.Runnable]{ + {Runnable: mock1}, + } - // Setup mock runnables and config - _, configCallback := setupMocksAndConfig() + configCallback := func() (*Config[*mocks.Runnable], error) { + return NewConfig("test", entries) + } - // Create runner - runner, err := NewRunner(configCallback) - require.NoError(t, err) + runner, err := NewRunner(configCallback) + require.NoError(t, err) - // Set up cancel function as Run() would - ctx, cancel := context.WithCancel(t.Context()) - runner.runnablesMu.Lock() - runner.ctx = ctx - runner.cancel = cancel - runner.runnablesMu.Unlock() + stopReturned := atomic.Bool{} + go func() { + runner.Stop() + stopReturned.Store(true) + }() - err = runner.fsm.SetState(finitestate.StatusRunning) - require.NoError(t, err) + time.Sleep(10 * time.Millisecond) + synctest.Wait() - // Call Stop - should just cancel the context, not change state - runner.Stop() + assert.False(t, stopReturned.Load(), "Stop should block until Run starts and completes") - // Verify state did not change (Stop only cancels context) - assert.Equal(t, finitestate.StatusRunning, runner.GetState()) + errCh := make(chan error, 1) + go func() { + errCh <- runner.Run(t.Context()) + }() - // Verify context was cancelled - select { - case <-ctx.Done(): - // Good, context was cancelled - default: - t.Error("Context should be cancelled after Stop()") - } + time.Sleep(10 * time.Millisecond) + synctest.Wait() + + assert.True(t, stopReturned.Load(), "Stop should unblock after Run completes") + require.NoError(t, <-errCh) + + mock1.AssertExpectations(t) + }) }) - t.Run("stop from non-running state", func(t *testing.T) { + t.Run("multiple stop calls", func(t *testing.T) { t.Parallel() + synctest.Test(t, func(t *testing.T) { + mock1 := mocks.NewMockRunnable() + mock1.DelayStop = 0 + mock1.On("String").Return("r1") + mock1.On("Run", mock.Anything).Run(func(args mock.Arguments) { + <-args.Get(0).(context.Context).Done() + }).Return(nil) + mock1.On("Stop").Once() + + entries := []RunnableEntry[*mocks.Runnable]{ + {Runnable: mock1}, + } - // Setup mock runnables and config - _, configCallback := setupMocksAndConfig() + configCallback := func() (*Config[*mocks.Runnable], error) { + return NewConfig("test", entries) + } - // Create runner and manually set state to Stopped - runner, err := NewRunner(configCallback) - require.NoError(t, err) - err = runner.fsm.SetState(finitestate.StatusStopped) - require.NoError(t, err) + runner, err := NewRunner(configCallback) + require.NoError(t, err) - // Call Stop - runner.Stop() + errCh := make(chan error, 1) + go func() { + errCh <- runner.Run(t.Context()) + }() - // Verify state did not change - assert.Equal(t, finitestate.StatusStopped, runner.GetState()) + time.Sleep(10 * time.Millisecond) + synctest.Wait() + assert.Equal(t, finitestate.StatusRunning, runner.GetState()) + + var wg sync.WaitGroup + for range 5 { + wg.Go(func() { + runner.Stop() + }) + } + wg.Wait() + + assert.Equal(t, finitestate.StatusStopped, runner.GetState()) + require.NoError(t, <-errCh) + mock1.AssertExpectations(t) + }) }) } diff --git a/runnables/httpcluster/runner.go b/runnables/httpcluster/runner.go index 9ef90d9..b68f11a 100644 --- a/runnables/httpcluster/runner.go +++ b/runnables/httpcluster/runner.go @@ -10,6 +10,7 @@ import ( "github.com/robbyt/go-supervisor/internal/finitestate" "github.com/robbyt/go-supervisor/runnables/httpserver" "github.com/robbyt/go-supervisor/supervisor" + "github.com/robbyt/go-supervisor/supervisor/lifecycle" ) const ( @@ -30,6 +31,7 @@ type fsm interface { // It implements supervisor.Runnable and supervisor.Stateable interfaces. type Runner struct { fsm fsm + lc *lifecycle.StartStop mu sync.RWMutex // runner factory creates the Runnable instances @@ -77,6 +79,7 @@ func defaultRunnerFactory( // NewRunner creates a new HTTP cluster runner with the provided options. func NewRunner(opts ...Option) (*Runner, error) { r := &Runner{ + lc: lifecycle.New(), runnerFactory: defaultRunnerFactory, logger: slog.Default().WithGroup("httpcluster.Runner"), restartDelay: defaultRestartDelay, @@ -99,7 +102,7 @@ func NewRunner(opts ...Option) (*Runner, error) { // Create FSM with the configured logger fsmLogger := r.logger.WithGroup("fsm") - machine, err := finitestate.New(fsmLogger.Handler()) + machine, err := finitestate.New(fsmLogger.Handler(), finitestate.TypicalTransitions) if err != nil { return nil, fmt.Errorf("unable to create fsm: %w", err) } @@ -175,10 +178,13 @@ func (r *Runner) Run(ctx context.Context) error { return fmt.Errorf("failed to transition to booting state: %w", err) } - // Set up local run context, share it to make it accessible to shutdown - r.mu.Lock() runCtx, runCancel := context.WithCancel(ctx) defer runCancel() + + done := r.lc.Started() + defer done() + + r.mu.Lock() r.ctx = runCtx r.cancel = runCancel r.mu.Unlock() @@ -196,6 +202,11 @@ func (r *Runner) Run(ctx context.Context) error { logger.Debug("Run context cancelled, initiating shutdown") return r.shutdown(runCtx) + case <-r.lc.StopCh(): + logger.Debug("Stop() called, initiating shutdown") + runCancel() + return r.shutdown(runCtx) + case newConfigs, ok := <-r.configSiphon: if !ok { logger.Debug("Config siphon closed, initiating shutdown") @@ -212,13 +223,11 @@ func (r *Runner) Run(ctx context.Context) error { } // Stop signals the cluster to stop all servers and shut down. +// It blocks until Run() has completed shutdown. func (r *Runner) Stop() { logger := r.logger.WithGroup("Stop") logger.Debug("Stopping") - - r.mu.Lock() - r.cancel() - r.mu.Unlock() + r.lc.Stop() } // shutdown performs graceful shutdown of all servers. diff --git a/runnables/httpserver/reload_test.go b/runnables/httpserver/reload_test.go index f0dbe2b..0a932c3 100644 --- a/runnables/httpserver/reload_test.go +++ b/runnables/httpserver/reload_test.go @@ -365,9 +365,6 @@ func TestReload(t *testing.T) { WithConfigCallback(newCfgCallback), ) require.NoError(t, err) - t.Cleanup(func() { - updatedServer.Stop() // Clean up the updated server too - }) // Replace the original server with our updated one for the test // We'll keep the original server's FSM state diff --git a/runnables/httpserver/runner.go b/runnables/httpserver/runner.go index 1e0c1c1..79c7af3 100644 --- a/runnables/httpserver/runner.go +++ b/runnables/httpserver/runner.go @@ -14,6 +14,7 @@ import ( "github.com/robbyt/go-supervisor/internal/finitestate" "github.com/robbyt/go-supervisor/supervisor" + "github.com/robbyt/go-supervisor/supervisor/lifecycle" ) // Interface guards verify implementation at compile time @@ -45,6 +46,7 @@ type fsm interface { // interfaces from the supervisor package. type Runner struct { fsm fsm + lc *lifecycle.StartStop mutex sync.RWMutex name string config atomic.Pointer[Config] @@ -67,6 +69,7 @@ func NewRunner(opts ...Option) (*Runner, error) { logger := slog.Default().WithGroup("httpserver.Runner") r := &Runner{ + lc: lifecycle.New(), name: "", config: atomic.Pointer[Config]{}, serverCloseOnce: sync.Once{}, @@ -86,7 +89,7 @@ func NewRunner(opts ...Option) (*Runner, error) { // Create FSM with the configured logger fsmLogger := r.logger.WithGroup("fsm") - machine, err := finitestate.New(fsmLogger.Handler()) + machine, err := finitestate.New(fsmLogger.Handler(), finitestate.TypicalTransitions) if err != nil { return nil, fmt.Errorf("unable to create fsm: %w", err) } @@ -125,7 +128,9 @@ func (r *Runner) Run(ctx context.Context) error { runCtx, runCancel := context.WithCancel(ctx) defer runCancel() - // Store the context and cancel function + done := r.lc.Started() + defer done() + r.mutex.Lock() r.ctx = runCtx r.cancel = runCancel @@ -156,6 +161,9 @@ func (r *Runner) Run(ctx context.Context) error { select { case <-runCtx.Done(): r.logger.Debug("Context canceled") + case <-r.lc.StopCh(): + r.logger.Debug("Stop() called") + runCancel() case err := <-r.serverErrors: r.setStateError() return fmt.Errorf("%w: %w", ErrHttpServer, err) @@ -164,19 +172,10 @@ func (r *Runner) Run(ctx context.Context) error { return r.shutdown(runCtx) } -// Stop signals the HTTP server to shut down by canceling its context. +// Stop signals the HTTP server to shut down and blocks until Run() completes. func (r *Runner) Stop() { r.logger.Debug("Stopping HTTP server") - - r.mutex.RLock() - cancel := r.cancel - r.mutex.RUnlock() - - if cancel == nil { - r.logger.Warn("Cancel function is nil, skipping Stop") - return - } - cancel() + r.lc.Stop() } // serverReadinessProbe verifies the HTTP server is accepting connections by diff --git a/runnables/httpserver/runner_shutdown_test.go b/runnables/httpserver/runner_shutdown_test.go index 32a473d..a359fbc 100644 --- a/runnables/httpserver/runner_shutdown_test.go +++ b/runnables/httpserver/runner_shutdown_test.go @@ -446,22 +446,20 @@ func TestRun_ShutdownDeadlineExceeded(t *testing.T) { } }() - // Wait for handler to start, then initiate shutdown - require.Eventually(t, func() bool { - select { - case <-started: - server.Stop() - return true - default: - return false - } - }, 2*time.Second, 10*time.Millisecond, "Handler did not start in time") + // Wait for handler to start + select { + case <-started: + case <-time.After(2 * time.Second): + t.Fatal("Handler did not start in time") + } - // Measure shutdown time + // Measure shutdown time — Stop() blocks until Run() completes start := time.Now() - err = <-done + server.Stop() elapsed := time.Since(start) + err = <-done + // Verify shutdown behavior with timeout error require.Error(t, err) require.ErrorIs(t, err, ErrGracefulShutdownTimeout, "Expected shutdown timeout error") @@ -528,22 +526,20 @@ func TestRun_ShutdownWithDrainTimeout(t *testing.T) { } }() - // Wait for handler to start, then initiate shutdown - require.Eventually(t, func() bool { - select { - case <-started: - server.Stop() - return true - default: - return false - } - }, 2*time.Second, 10*time.Millisecond, "Handler did not start in time") + // Wait for handler to start + select { + case <-started: + case <-time.After(2 * time.Second): + t.Fatal("Handler did not start in time") + } - // Measure shutdown time + // Measure shutdown time — Stop() blocks until Run() completes start := time.Now() - err = <-done + server.Stop() elapsed := time.Since(start) + err = <-done + // Verify shutdown behavior require.NoError(t, err) require.GreaterOrEqual(t, elapsed.Seconds(), sleepDuration.Seconds(), diff --git a/runnables/httpserver/runner_test.go b/runnables/httpserver/runner_test.go index fea4e9f..feeae0c 100644 --- a/runnables/httpserver/runner_test.go +++ b/runnables/httpserver/runner_test.go @@ -190,9 +190,6 @@ func TestStopServerWhenNotRunning(t *testing.T) { server, listenPort := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {}, "/", 1*time.Second) t.Logf("Server listening on port %s", listenPort) - t.Cleanup(func() { - server.Stop() - }) err := server.stopServer(context.Background()) require.Error(t, err) diff --git a/runnables/httpserver/state_mocked_test.go b/runnables/httpserver/state_mocked_test.go index 8b64a01..55dacc2 100644 --- a/runnables/httpserver/state_mocked_test.go +++ b/runnables/httpserver/state_mocked_test.go @@ -21,9 +21,6 @@ func TestSetStateError_FullIntegration(t *testing.T) { server, listenPort := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {}, "/test", 1*time.Second) t.Logf("Server listening on port %s", listenPort) - t.Cleanup(func() { - server.Stop() - }) // Set up initial conditions with FSM in Stopped state err := server.fsm.SetState(finitestate.StatusStopped) diff --git a/runnables/httpserver/state_test.go b/runnables/httpserver/state_test.go index b375b79..9c9aaa5 100644 --- a/runnables/httpserver/state_test.go +++ b/runnables/httpserver/state_test.go @@ -19,9 +19,6 @@ func TestGetState(t *testing.T) { server, listenPort := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {}, "/", 1*time.Second) t.Logf("Server listening on port %s", listenPort) - t.Cleanup(func() { - server.Stop() - }) // Test initial state assert.Equal(t, finitestate.StatusNew, server.GetState(), "Initial state should be New") @@ -47,9 +44,6 @@ func TestGetStateChan(t *testing.T) { server, listenPort := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {}, "/", 1*time.Second) t.Logf("Server listening on port %s", listenPort) - t.Cleanup(func() { - server.Stop() - }) // Create a context with timeout for safety ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second) @@ -87,9 +81,6 @@ func TestIsRunning(t *testing.T) { server, listenPort := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {}, "/", 1*time.Second) t.Logf("Server listening on port %s", listenPort) - t.Cleanup(func() { - server.Stop() - }) // Test when state is not running err := server.fsm.SetState(finitestate.StatusNew) @@ -131,9 +122,6 @@ func TestSetStateError(t *testing.T) { server, listenPort := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {}, "/test", 1*time.Second) t.Logf("Server listening on port %s", listenPort) - t.Cleanup(func() { - server.Stop() - }) // Set a known state that can transition to error err := server.fsm.SetState(finitestate.StatusNew) @@ -149,9 +137,6 @@ func TestSetStateError(t *testing.T) { server, listenPort := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {}, "/test", 1*time.Second) t.Logf("Server listening on port %s", listenPort) - t.Cleanup(func() { - server.Stop() - }) // Set a state that won't normally transition to error // Force it to be in Running state first @@ -172,14 +157,11 @@ func TestSetStateError(t *testing.T) { server, listenPort := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {}, "/test", 1*time.Second) t.Logf("Server listening on port %s", listenPort) - t.Cleanup(func() { - server.Stop() - }) // Get typical transitions to use as base logger := slog.Default().WithGroup("testFSM") // Create valid FSM but force it to a specific state - validFSM, err := finitestate.New(logger.Handler()) + validFSM, err := finitestate.New(logger.Handler(), finitestate.TypicalTransitions) require.NoError(t, err) // Force it to be in Stopping state @@ -208,9 +190,6 @@ func TestWaitForState(t *testing.T) { server, listenPort := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {}, "/", 1*time.Second) t.Logf("Server listening on port %s", listenPort) - t.Cleanup(func() { - server.Stop() - }) // Set the state to Running err := server.fsm.SetState(finitestate.StatusRunning) @@ -238,9 +217,6 @@ func TestGetStateChanWithTimeout(t *testing.T) { server, listenPort := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {}, "/", 1*time.Second) t.Logf("Server listening on port %s", listenPort) - t.Cleanup(func() { - server.Stop() - }) ctx, cancel := context.WithCancel(t.Context()) defer cancel() diff --git a/supervisor/lifecycle/startstop.go b/supervisor/lifecycle/startstop.go new file mode 100644 index 0000000..a4aabae --- /dev/null +++ b/supervisor/lifecycle/startstop.go @@ -0,0 +1,54 @@ +package lifecycle + +import "sync" + +// StartStop manages the Run/Stop synchronization for a Runnable. +// It ensures Stop() blocks until Run() has completed, handling all +// orderings: stop-before-run, stop-during-run, stop-after-run, +// and multiple concurrent Stop() calls. +type StartStop struct { + mu sync.Mutex + stopOnce sync.Once + startOnce sync.Once + stopCh chan struct{} + startedCh chan struct{} + doneCh chan struct{} +} + +// New creates a new StartStop instance. +func New() *StartStop { + return &StartStop{ + stopCh: make(chan struct{}), + startedCh: make(chan struct{}), + } +} + +// Started is called at the beginning of Run(). It returns a done function +// that must be deferred to signal Run() completion. +func (l *StartStop) Started() (done func()) { + doneCh := make(chan struct{}) + l.mu.Lock() + l.doneCh = doneCh + l.mu.Unlock() + l.startOnce.Do(func() { close(l.startedCh) }) + var doneOnce sync.Once + return func() { doneOnce.Do(func() { close(doneCh) }) } +} + +// Stop signals the Runnable to stop and blocks until Run() completes. +// If Run() has not been called yet, Stop blocks until it starts and finishes. +// Safe to call from multiple goroutines concurrently. +func (l *StartStop) Stop() { + l.stopOnce.Do(func() { close(l.stopCh) }) + <-l.startedCh + l.mu.Lock() + doneCh := l.doneCh + l.mu.Unlock() + <-doneCh +} + +// StopCh returns a channel that is closed when Stop() is called. +// Use this in a select statement within Run() to detect stop signals. +func (l *StartStop) StopCh() <-chan struct{} { + return l.stopCh +} diff --git a/supervisor/lifecycle/startstop_test.go b/supervisor/lifecycle/startstop_test.go new file mode 100644 index 0000000..f8cb51b --- /dev/null +++ b/supervisor/lifecycle/startstop_test.go @@ -0,0 +1,147 @@ +package lifecycle + +import ( + "sync" + "sync/atomic" + "testing" + "testing/synctest" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestStartStop_StopBlocksUntilRunCompletes(t *testing.T) { + t.Parallel() + synctest.Test(t, func(t *testing.T) { + lc := New() + + runReturned := atomic.Bool{} + go func() { + done := lc.Started() + defer done() + <-lc.StopCh() + runReturned.Store(true) + }() + + time.Sleep(time.Second) + synctest.Wait() + + lc.Stop() + + assert.True(t, runReturned.Load(), "Run should have returned before Stop unblocked") + }) +} + +func TestStartStop_StopBeforeRun(t *testing.T) { + t.Parallel() + synctest.Test(t, func(t *testing.T) { + lc := New() + + stopReturned := atomic.Bool{} + go func() { + lc.Stop() + stopReturned.Store(true) + }() + + time.Sleep(time.Second) + synctest.Wait() + + assert.False(t, stopReturned.Load(), "Stop should block until Run starts and completes") + + done := lc.Started() + done() + + time.Sleep(time.Second) + synctest.Wait() + + assert.True(t, stopReturned.Load(), "Stop should unblock after Run completes") + }) +} + +func TestStartStop_StopAfterRun(t *testing.T) { + t.Parallel() + synctest.Test(t, func(t *testing.T) { + lc := New() + + done := lc.Started() + done() + + lc.Stop() + }) +} + +func TestStartStop_MultipleConcurrentStops(t *testing.T) { + t.Parallel() + synctest.Test(t, func(t *testing.T) { + lc := New() + + go func() { + done := lc.Started() + defer done() + <-lc.StopCh() + }() + + time.Sleep(time.Second) + synctest.Wait() + + var wg sync.WaitGroup + for range 5 { + wg.Go(func() { + lc.Stop() + }) + } + wg.Wait() + }) +} + +func TestStartStop_DoubleStartedDoesNotPanic(t *testing.T) { + t.Parallel() + synctest.Test(t, func(t *testing.T) { + lc := New() + + done1 := lc.Started() + done2 := lc.Started() // must not panic + done1() + done2() + + lc.Stop() + }) +} + +func TestStartStop_DoubleDoneDoesNotPanic(t *testing.T) { + t.Parallel() + synctest.Test(t, func(t *testing.T) { + lc := New() + + done := lc.Started() + done() + done() // must not panic + + lc.Stop() + }) +} + +func TestStartStop_StopChClosedAfterStop(t *testing.T) { + t.Parallel() + synctest.Test(t, func(t *testing.T) { + lc := New() + + go func() { + done := lc.Started() + defer done() + <-lc.StopCh() + }() + + time.Sleep(time.Second) + synctest.Wait() + + lc.Stop() + + select { + case <-lc.StopCh(): + // expected + default: + t.Fatal("StopCh should be closed after Stop()") + } + }) +} From f31573f0b716fa5c78ee7e1ca9903052d14c90c3 Mon Sep 17 00:00:00 2001 From: Robert Terhaar Date: Sun, 8 Feb 2026 12:01:05 -0500 Subject: [PATCH 2/3] simplify calls to our finitestate abstraction by creating a new constructor --- internal/finitestate/machine.go | 9 +++++++-- internal/finitestate/machine_test.go | 10 +++++----- runnables/composite/runner.go | 5 +---- runnables/httpcluster/runner.go | 2 +- runnables/httpserver/runner.go | 2 +- runnables/httpserver/state_test.go | 2 +- 6 files changed, 16 insertions(+), 14 deletions(-) diff --git a/internal/finitestate/machine.go b/internal/finitestate/machine.go index ce2bbd8..2523507 100644 --- a/internal/finitestate/machine.go +++ b/internal/finitestate/machine.go @@ -21,8 +21,8 @@ const ( StatusUnknown = transitions.StatusUnknown ) -// TypicalTransitions is a set of standard transitions for a finite state machine. -var TypicalTransitions = transitions.Typical +// typicalTransitions is a set of standard transitions for a finite state machine. +var typicalTransitions = transitions.Typical // Machine wraps go-fsm v2 to provide a simplified API with broadcast support. type Machine struct { @@ -89,3 +89,8 @@ func New(handler slog.Handler, t *transitions.Config) (*Machine, error) { return &Machine{Machine: f}, nil } + +// NewTypicalFSM creates a new finite state machine with standard transitions. +func NewTypicalFSM(handler slog.Handler) (*Machine, error) { + return New(handler, typicalTransitions) +} diff --git a/internal/finitestate/machine_test.go b/internal/finitestate/machine_test.go index 6a95311..f95ff36 100644 --- a/internal/finitestate/machine_test.go +++ b/internal/finitestate/machine_test.go @@ -16,7 +16,7 @@ func TestNew(t *testing.T) { t.Run("creates new machine with correct initial state", func(t *testing.T) { handler := slog.NewTextHandler(os.Stdout, nil) - machine, err := New(handler, TypicalTransitions) + machine, err := New(handler, typicalTransitions) require.NoError(t, err) require.NotNil(t, machine) @@ -26,7 +26,7 @@ func TestNew(t *testing.T) { t.Run("uses provided handler", func(t *testing.T) { // Create a test handler handler := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}) - machine, err := New(handler, TypicalTransitions) + machine, err := New(handler, typicalTransitions) require.NoError(t, err) require.NotNil(t, machine) @@ -38,7 +38,7 @@ func TestMachineInterface(t *testing.T) { setup := func() *Machine { handler := slog.NewTextHandler(os.Stdout, nil) - m, err := New(handler, TypicalTransitions) + m, err := New(handler, typicalTransitions) require.NoError(t, err) return m } @@ -183,7 +183,7 @@ func TestTypicalTransitions(t *testing.T) { t.Parallel() t.Run("verify TypicalTransitions is not nil", func(t *testing.T) { - assert.NotNil(t, TypicalTransitions) + assert.NotNil(t, typicalTransitions) }) t.Run("verify status constants are defined", func(t *testing.T) { @@ -203,7 +203,7 @@ func TestGetStateChanWithTimeout(t *testing.T) { setup := func() *Machine { handler := slog.NewTextHandler(os.Stdout, nil) - m, err := New(handler, TypicalTransitions) + m, err := New(handler, typicalTransitions) require.NoError(t, err) return m } diff --git a/runnables/composite/runner.go b/runnables/composite/runner.go index 7facc35..f25de7d 100644 --- a/runnables/composite/runner.go +++ b/runnables/composite/runner.go @@ -76,10 +76,7 @@ func NewRunner[T runnable]( } // Create FSM after the optional logger has been configured - fsm, err := finitestate.New( - r.logger.WithGroup("fsm").Handler(), - finitestate.TypicalTransitions, - ) + fsm, err := finitestate.NewTypicalFSM(r.logger.WithGroup("fsm").Handler()) if err != nil { return nil, fmt.Errorf("unable to create fsm: %w", err) } diff --git a/runnables/httpcluster/runner.go b/runnables/httpcluster/runner.go index b68f11a..ccee38b 100644 --- a/runnables/httpcluster/runner.go +++ b/runnables/httpcluster/runner.go @@ -102,7 +102,7 @@ func NewRunner(opts ...Option) (*Runner, error) { // Create FSM with the configured logger fsmLogger := r.logger.WithGroup("fsm") - machine, err := finitestate.New(fsmLogger.Handler(), finitestate.TypicalTransitions) + machine, err := finitestate.NewTypicalFSM(fsmLogger.Handler()) if err != nil { return nil, fmt.Errorf("unable to create fsm: %w", err) } diff --git a/runnables/httpserver/runner.go b/runnables/httpserver/runner.go index 79c7af3..4a6a8e3 100644 --- a/runnables/httpserver/runner.go +++ b/runnables/httpserver/runner.go @@ -89,7 +89,7 @@ func NewRunner(opts ...Option) (*Runner, error) { // Create FSM with the configured logger fsmLogger := r.logger.WithGroup("fsm") - machine, err := finitestate.New(fsmLogger.Handler(), finitestate.TypicalTransitions) + machine, err := finitestate.NewTypicalFSM(fsmLogger.Handler()) if err != nil { return nil, fmt.Errorf("unable to create fsm: %w", err) } diff --git a/runnables/httpserver/state_test.go b/runnables/httpserver/state_test.go index 9c9aaa5..83881ea 100644 --- a/runnables/httpserver/state_test.go +++ b/runnables/httpserver/state_test.go @@ -161,7 +161,7 @@ func TestSetStateError(t *testing.T) { // Get typical transitions to use as base logger := slog.Default().WithGroup("testFSM") // Create valid FSM but force it to a specific state - validFSM, err := finitestate.New(logger.Handler(), finitestate.TypicalTransitions) + validFSM, err := finitestate.NewTypicalFSM(logger.Handler()) require.NoError(t, err) // Force it to be in Stopping state From f78993905ed6028906d95066886a7e2f1aa807a2 Mon Sep 17 00:00:00 2001 From: Robert Terhaar Date: Mon, 9 Feb 2026 12:48:00 -0500 Subject: [PATCH 3/3] switch back to the broadcast manager for go-fsm --- internal/finitestate/machine.go | 64 +++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 26 deletions(-) diff --git a/internal/finitestate/machine.go b/internal/finitestate/machine.go index 2523507..156751c 100644 --- a/internal/finitestate/machine.go +++ b/internal/finitestate/machine.go @@ -7,6 +7,7 @@ import ( "github.com/robbyt/go-fsm/v2" "github.com/robbyt/go-fsm/v2/hooks" + "github.com/robbyt/go-fsm/v2/hooks/broadcast" "github.com/robbyt/go-fsm/v2/transitions" ) @@ -27,43 +28,40 @@ var typicalTransitions = transitions.Typical // Machine wraps go-fsm v2 to provide a simplified API with broadcast support. type Machine struct { *fsm.Machine + broadcastManager *broadcast.Manager } // GetStateChan returns a channel that emits the state whenever it changes. // The current state is sent immediately. The channel is closed when the // provided context is canceled. +// A 5-second broadcast timeout prevents slow consumers from blocking state updates. func (s *Machine) GetStateChan(ctx context.Context) <-chan string { - sourceCh := make(chan string, 1) - if err := s.Machine.GetStateChan(ctx, sourceCh); err != nil { - ch := make(chan string) - close(ch) - return ch + return s.getStateChanInternal(ctx, broadcast.WithTimeout(5*time.Second)) +} + +// getStateChanInternal subscribes to state changes via the broadcast manager +// and sends the current state immediately for compatibility with callers that +// expect the initial state on the channel. +func (s *Machine) getStateChanInternal(ctx context.Context, opts ...broadcast.Option) <-chan string { + wrappedCh := make(chan string, 1) + + userCh, err := s.broadcastManager.GetStateChan(ctx, opts...) + if err != nil { + close(wrappedCh) + return wrappedCh } - // go-fsm sends the current state synchronously into sourceCh before - // returning. Forward it into outCh now so callers can do non-blocking reads. - outCh := make(chan string, 1) - outCh <- <-sourceCh + currentState := s.GetState() + wrappedCh <- currentState go func() { - defer close(outCh) - for { - select { - case state := <-sourceCh: - outCh <- state - case <-ctx.Done(): - // Drain any state already buffered in sourceCh before closing. - select { - case state := <-sourceCh: - outCh <- state - default: - } - return - } + defer close(wrappedCh) + for state := range userCh { + wrappedCh <- state } }() - return outCh + return wrappedCh } // New creates a new finite state machine with the specified logger and transitions. @@ -76,18 +74,32 @@ func New(handler slog.Handler, t *transitions.Config) (*Machine, error) { return nil, err } + broadcastManager := broadcast.NewManager(handler) + + err = registry.RegisterPostTransitionHook(hooks.PostTransitionHookConfig{ + Name: "broadcast", + From: []string{"*"}, + To: []string{"*"}, + Action: broadcastManager.BroadcastHook, + }) + if err != nil { + return nil, err + } + f, err := fsm.New( StatusNew, t, fsm.WithLogHandler(handler), fsm.WithCallbackRegistry(registry), - fsm.WithBroadcastTimeout(5*time.Second), ) if err != nil { return nil, err } - return &Machine{Machine: f}, nil + return &Machine{ + Machine: f, + broadcastManager: broadcastManager, + }, nil } // NewTypicalFSM creates a new finite state machine with standard transitions.