Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 18 additions & 14 deletions internal/finitestate/machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,26 @@ const (
StatusUnknown = transitions.StatusUnknown
)

// TypicalTransitions is a set of standard transitions for a finite state machine.
var TypicalTransitions = transitions.Typical
// typicalTransitions is a set of standard transitions for a finite state machine.
var typicalTransitions = transitions.Typical

// Machine is a wrapper around go-fsm v2 that provides the v1 API compatibility.
// It manages both the FSM and broadcast functionality.
// Machine wraps go-fsm v2 to provide a simplified API with broadcast support.
type Machine struct {
	// Embedded go-fsm machine; its exported methods are promoted onto Machine.
	*fsm.Machine
	// broadcastManager backs the state-change subscriptions handed out by
	// GetStateChan.
	broadcastManager *broadcast.Manager
}

// GetStateChan returns a channel that emits the state whenever it changes.
// The channel is closed when the provided context is canceled.
// For v1 API compatibility, the current state is sent immediately to the channel.
// A 5-second broadcast timeout is used to prevent slow consumers from blocking state updates.
// The current state is sent immediately. The channel is closed when the
// provided context is canceled.
// A 5-second broadcast timeout prevents slow consumers from blocking state updates.
func (s *Machine) GetStateChan(ctx context.Context) <-chan string {
	// Bound broadcast sends to 5 seconds so slow consumers cannot block
	// state updates.
	timeoutOpt := broadcast.WithTimeout(5 * time.Second)
	return s.getStateChanInternal(ctx, timeoutOpt)
}

// getStateChanInternal is a helper that creates a channel and sends the current state to it.
// This maintains v1 API compatibility where GetStateChan immediately sends the current state.
// getStateChanInternal subscribes to state changes via the broadcast manager
// and sends the current state immediately for compatibility with callers that
// expect the initial state on the channel.
func (s *Machine) getStateChanInternal(ctx context.Context, opts ...broadcast.Option) <-chan string {
wrappedCh := make(chan string, 1)

Expand All @@ -64,12 +64,11 @@ func (s *Machine) getStateChanInternal(ctx context.Context, opts ...broadcast.Op
return wrappedCh
}

// New creates a new finite state machine with the specified logger using "standard" state transitions.
// This function provides compatibility with the v1 API while using v2 under the hood.
func New(handler slog.Handler) (*Machine, error) {
// New creates a new finite state machine with the specified logger and transitions.
func New(handler slog.Handler, t *transitions.Config) (*Machine, error) {
registry, err := hooks.NewRegistry(
hooks.WithLogHandler(handler),
hooks.WithTransitions(TypicalTransitions),
hooks.WithTransitions(t),
)
if err != nil {
return nil, err
Expand All @@ -89,7 +88,7 @@ func New(handler slog.Handler) (*Machine, error) {

f, err := fsm.New(
StatusNew,
TypicalTransitions,
t,
fsm.WithLogHandler(handler),
fsm.WithCallbackRegistry(registry),
)
Expand All @@ -102,3 +101,8 @@ func New(handler slog.Handler) (*Machine, error) {
broadcastManager: broadcastManager,
}, nil
}

// NewTypicalFSM builds a Machine that logs through handler and uses the
// package's standard transition set (typicalTransitions). It is a convenience
// wrapper around New and returns New's result unchanged.
func NewTypicalFSM(handler slog.Handler) (*Machine, error) {
	machine, err := New(handler, typicalTransitions)
	return machine, err
}
10 changes: 5 additions & 5 deletions internal/finitestate/machine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func TestNew(t *testing.T) {

t.Run("creates new machine with correct initial state", func(t *testing.T) {
handler := slog.NewTextHandler(os.Stdout, nil)
machine, err := New(handler)
machine, err := New(handler, typicalTransitions)

require.NoError(t, err)
require.NotNil(t, machine)
Expand All @@ -26,7 +26,7 @@ func TestNew(t *testing.T) {
t.Run("uses provided handler", func(t *testing.T) {
// Create a test handler
handler := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})
machine, err := New(handler)
machine, err := New(handler, typicalTransitions)

require.NoError(t, err)
require.NotNil(t, machine)
Expand All @@ -38,7 +38,7 @@ func TestMachineInterface(t *testing.T) {

setup := func() *Machine {
handler := slog.NewTextHandler(os.Stdout, nil)
m, err := New(handler)
m, err := New(handler, typicalTransitions)
require.NoError(t, err)
return m
}
Expand Down Expand Up @@ -183,7 +183,7 @@ func TestTypicalTransitions(t *testing.T) {
t.Parallel()

t.Run("verify TypicalTransitions is not nil", func(t *testing.T) {
assert.NotNil(t, TypicalTransitions)
assert.NotNil(t, typicalTransitions)
})

t.Run("verify status constants are defined", func(t *testing.T) {
Expand All @@ -203,7 +203,7 @@ func TestGetStateChanWithTimeout(t *testing.T) {

setup := func() *Machine {
handler := slog.NewTextHandler(os.Stdout, nil)
m, err := New(handler)
m, err := New(handler, typicalTransitions)
require.NoError(t, err)
return m
}
Expand Down
28 changes: 13 additions & 15 deletions runnables/composite/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"sync/atomic"

"github.com/robbyt/go-supervisor/internal/finitestate"
"github.com/robbyt/go-supervisor/supervisor/lifecycle"
)

// ConfigCallback is the function type signature for the callback used to load initial config, and new config during Reload()
Expand All @@ -26,6 +27,7 @@ type fsm interface {
// as a single unit. It satisfies the Runnable, Reloadable, and Stateable interfaces.
type Runner[T runnable] struct {
fsm fsm
lc *lifecycle.StartStop
configMu sync.Mutex // Only used for getConfig()
currentConfig atomic.Pointer[Config[T]]
configCallback ConfigCallback[T]
Expand Down Expand Up @@ -53,6 +55,7 @@ func NewRunner[T runnable](
) (*Runner[T], error) {
logger := slog.Default().WithGroup("composite.Runner")
r := &Runner[T]{
lc: lifecycle.New(),
currentConfig: atomic.Pointer[Config[T]]{},
configCallback: configCallback,
serverErrors: make(chan error, 1),
Expand All @@ -73,9 +76,7 @@ func NewRunner[T runnable](
}

// Create FSM after the optional logger has been configured
fsm, err := finitestate.New(
r.logger.WithGroup("fsm").Handler(),
)
fsm, err := finitestate.NewTypicalFSM(r.logger.WithGroup("fsm").Handler())
if err != nil {
return nil, fmt.Errorf("unable to create fsm: %w", err)
}
Expand All @@ -99,7 +100,9 @@ func (r *Runner[T]) Run(ctx context.Context) error {
runCtx, runCancel := context.WithCancel(ctx)
defer runCancel()

// store the Run context and cancel function in the runner so that Reload() and Stop() can use them later
done := r.lc.Started()
defer done()

r.runnablesMu.Lock()
r.ctx = runCtx
r.cancel = runCancel
Expand All @@ -122,10 +125,13 @@ func (r *Runner[T]) Run(ctx context.Context) error {
return fmt.Errorf("failed to transition to Running state: %w", err)
}

// Wait for context cancellation or errors
// Wait for context cancellation, stop signal, or errors
select {
case <-runCtx.Done():
r.logger.Debug("Local context canceled")
case <-r.lc.StopCh():
r.logger.Debug("Stop() called")
runCancel()
case err := <-r.serverErrors:
r.setStateError()
stopErr := r.stopAllRunnables()
Expand Down Expand Up @@ -153,17 +159,9 @@ func (r *Runner[T]) Run(ctx context.Context) error {
return nil
}

// Stop will cancel the context, causing all child runnables to stop.
// Stop signals the runner to shut down and blocks until Run() completes.
// If Stop is called before Run(), it blocks until Run() has started and completed.
func (r *Runner[T]) Stop() {
r.runnablesMu.Lock()
cancel := r.cancel
r.runnablesMu.Unlock()

if cancel == nil {
r.logger.Warn("Cancel function is nil, skipping Stop")
return
}
cancel()
r.lc.Stop()
}

// boot starts all child runnables in the order they're defined.
Expand Down
165 changes: 114 additions & 51 deletions runnables/composite/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"log/slog"
"os"
"sync"
"sync/atomic"
"testing"
"testing/synctest"
Expand Down Expand Up @@ -427,76 +428,138 @@ func TestCompositeRunner_Run(t *testing.T) {
func TestCompositeRunner_Stop(t *testing.T) {
t.Parallel()

// Create reusable mock setup function
setupMocksAndConfig := func() (entries []RunnableEntry[*mocks.Runnable], configFunc func() (*Config[*mocks.Runnable], error)) {
mockRunnable1 := mocks.NewMockRunnable()
mockRunnable2 := mocks.NewMockRunnable()
t.Run("stop blocks until run completes", func(t *testing.T) {
t.Parallel()
synctest.Test(t, func(t *testing.T) {
mock1 := mocks.NewMockRunnable()
mock1.On("String").Return("r1")
mock1.On("Run", mock.Anything).Run(func(args mock.Arguments) {
<-args.Get(0).(context.Context).Done()
}).Return(nil)
mock1.On("Stop").Once()

entries := []RunnableEntry[*mocks.Runnable]{
{Runnable: mock1},
}

entries = []RunnableEntry[*mocks.Runnable]{
{Runnable: mockRunnable1, Config: nil},
{Runnable: mockRunnable2, Config: nil},
}
configCallback := func() (*Config[*mocks.Runnable], error) {
return NewConfig("test", entries)
}

// Create config callback
configFunc = func() (*Config[*mocks.Runnable], error) {
return NewConfig("test", entries)
}
runner, err := NewRunner(configCallback)
require.NoError(t, err)

return entries, configFunc
}
errCh := make(chan error, 1)
runReturned := atomic.Bool{}
go func() {
errCh <- runner.Run(t.Context())
runReturned.Store(true)
}()

time.Sleep(10 * time.Millisecond)
synctest.Wait()
assert.Equal(t, finitestate.StatusRunning, runner.GetState())

runner.Stop()

assert.True(t, runReturned.Load(), "Run should have returned before Stop unblocked")
assert.Equal(t, finitestate.StatusStopped, runner.GetState())
require.NoError(t, <-errCh)

mock1.AssertExpectations(t)
})
})

t.Run("stop from running state", func(t *testing.T) {
t.Run("stop before run", func(t *testing.T) {
t.Parallel()
synctest.Test(t, func(t *testing.T) {
mock1 := mocks.NewMockRunnable()
mock1.On("String").Return("r1")
mock1.On("Run", mock.Anything).Run(func(args mock.Arguments) {
<-args.Get(0).(context.Context).Done()
}).Return(nil)
mock1.On("Stop").Once()

entries := []RunnableEntry[*mocks.Runnable]{
{Runnable: mock1},
}

// Setup mock runnables and config
_, configCallback := setupMocksAndConfig()
configCallback := func() (*Config[*mocks.Runnable], error) {
return NewConfig("test", entries)
}

// Create runner
runner, err := NewRunner(configCallback)
require.NoError(t, err)
runner, err := NewRunner(configCallback)
require.NoError(t, err)

// Set up cancel function as Run() would
ctx, cancel := context.WithCancel(t.Context())
runner.runnablesMu.Lock()
runner.ctx = ctx
runner.cancel = cancel
runner.runnablesMu.Unlock()
stopReturned := atomic.Bool{}
go func() {
runner.Stop()
stopReturned.Store(true)
}()

err = runner.fsm.SetState(finitestate.StatusRunning)
require.NoError(t, err)
time.Sleep(10 * time.Millisecond)
synctest.Wait()

// Call Stop - should just cancel the context, not change state
runner.Stop()
assert.False(t, stopReturned.Load(), "Stop should block until Run starts and completes")

// Verify state did not change (Stop only cancels context)
assert.Equal(t, finitestate.StatusRunning, runner.GetState())
errCh := make(chan error, 1)
go func() {
errCh <- runner.Run(t.Context())
}()

// Verify context was cancelled
select {
case <-ctx.Done():
// Good, context was cancelled
default:
t.Error("Context should be cancelled after Stop()")
}
time.Sleep(10 * time.Millisecond)
synctest.Wait()

assert.True(t, stopReturned.Load(), "Stop should unblock after Run completes")
require.NoError(t, <-errCh)

mock1.AssertExpectations(t)
})
})

t.Run("stop from non-running state", func(t *testing.T) {
t.Run("multiple stop calls", func(t *testing.T) {
t.Parallel()
synctest.Test(t, func(t *testing.T) {
mock1 := mocks.NewMockRunnable()
mock1.DelayStop = 0
mock1.On("String").Return("r1")
mock1.On("Run", mock.Anything).Run(func(args mock.Arguments) {
<-args.Get(0).(context.Context).Done()
}).Return(nil)
mock1.On("Stop").Once()

entries := []RunnableEntry[*mocks.Runnable]{
{Runnable: mock1},
}

// Setup mock runnables and config
_, configCallback := setupMocksAndConfig()
configCallback := func() (*Config[*mocks.Runnable], error) {
return NewConfig("test", entries)
}

// Create runner and manually set state to Stopped
runner, err := NewRunner(configCallback)
require.NoError(t, err)
err = runner.fsm.SetState(finitestate.StatusStopped)
require.NoError(t, err)
runner, err := NewRunner(configCallback)
require.NoError(t, err)

// Call Stop
runner.Stop()
errCh := make(chan error, 1)
go func() {
errCh <- runner.Run(t.Context())
}()

// Verify state did not change
assert.Equal(t, finitestate.StatusStopped, runner.GetState())
time.Sleep(10 * time.Millisecond)
synctest.Wait()
assert.Equal(t, finitestate.StatusRunning, runner.GetState())

var wg sync.WaitGroup
for range 5 {
wg.Go(func() {
runner.Stop()
})
}
wg.Wait()

assert.Equal(t, finitestate.StatusStopped, runner.GetState())
require.NoError(t, <-errCh)
mock1.AssertExpectations(t)
})
})
}

Expand Down
Loading
Loading