Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions runnables/composite/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func (r *Runner[T]) Run(ctx context.Context) error {
}

// Start all child runnables
if err := r.boot(ctx); err != nil {
if err := r.boot(runCtx); err != nil {
r.setStateError()
return fmt.Errorf("failed to start child runnables: %w", err)
}
Expand All @@ -128,7 +128,8 @@ func (r *Runner[T]) Run(ctx context.Context) error {
r.logger.Debug("Local context canceled")
case err := <-r.serverErrors:
r.setStateError()
return fmt.Errorf("%w: %w", ErrRunnableFailed, err)
stopErr := r.stopAllRunnables()
return fmt.Errorf("%w: %w", ErrRunnableFailed, errors.Join(err, stopErr))
}

if err := r.fsm.TransitionIfCurrentState(finitestate.StatusRunning, finitestate.StatusStopping); err != nil {
Expand Down
124 changes: 124 additions & 0 deletions runnables/composite/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ func TestCompositeRunner_Run(t *testing.T) {
// Wait until context is cancelled
<-args.Get(0).(context.Context).Done()
}).Return(nil)
mockRunnable1.On("Stop").Once()

// Create entries
entries := []RunnableEntry[*mocks.Runnable]{
Expand Down Expand Up @@ -235,6 +236,7 @@ func TestCompositeRunner_Run(t *testing.T) {
mockRunnable := mocks.NewMockRunnable()
mockRunnable.On("String").Return("runnable1")
mockRunnable.On("Run", mock.Anything).Return(errors.New("failed to start"))
mockRunnable.On("Stop").Once()

entries := []RunnableEntry[*mocks.Runnable]{
{Runnable: mockRunnable, Config: nil},
Expand Down Expand Up @@ -498,6 +500,125 @@ func TestCompositeRunner_Stop(t *testing.T) {
})
}

func TestCompositeRunner_StopCancelsChildContexts(t *testing.T) {
t.Parallel()
synctest.Test(t, func(t *testing.T) {
childDone := make(chan struct{}, 2)

mock1 := mocks.NewMockRunnable()
mock1.On("String").Return("r1")
mock1.On("Run", mock.Anything).Run(func(args mock.Arguments) {
<-args.Get(0).(context.Context).Done()
childDone <- struct{}{}
}).Return(nil)
mock1.On("Stop").Once()

mock2 := mocks.NewMockRunnable()
mock2.On("String").Return("r2")
mock2.On("Run", mock.Anything).Run(func(args mock.Arguments) {
<-args.Get(0).(context.Context).Done()
childDone <- struct{}{}
}).Return(nil)
mock2.On("Stop").Once()

entries := []RunnableEntry[*mocks.Runnable]{
{Runnable: mock1},
{Runnable: mock2},
}

configCallback := func() (*Config[*mocks.Runnable], error) {
return NewConfig("test", entries)
}

runner, err := NewRunner(configCallback)
require.NoError(t, err)

errCh := make(chan error, 1)
go func() {
errCh <- runner.Run(t.Context())
}()

time.Sleep(10 * time.Millisecond)
synctest.Wait()

assert.Equal(t, finitestate.StatusRunning, runner.GetState())

runner.Stop()

time.Sleep(10 * time.Millisecond)
synctest.Wait()

select {
case err := <-errCh:
require.NoError(t, err)
default:
t.Fatal("Run did not return after Stop — children may not derive from runCtx")
}

for range 2 {
select {
case <-childDone:
default:
t.Fatal("child goroutine did not exit after Stop")
}
}

mock1.AssertExpectations(t)
mock2.AssertExpectations(t)
})
}

func TestCompositeRunner_ErrorPathStopsSiblings(t *testing.T) {
t.Parallel()
synctest.Test(t, func(t *testing.T) {
survivor := mocks.NewMockRunnable()
survivor.On("String").Return("survivor")
survivor.On("Run", mock.Anything).Run(func(args mock.Arguments) {
<-args.Get(0).(context.Context).Done()
}).Return(nil)
survivor.On("Stop").Once()

failer := mocks.NewMockRunnable()
failer.On("String").Return("failer")
failer.On("Run", mock.Anything).Run(func(args mock.Arguments) {
time.Sleep(50 * time.Millisecond)
}).Return(errors.New("boom"))
failer.On("Stop").Once()

entries := []RunnableEntry[*mocks.Runnable]{
{Runnable: survivor},
{Runnable: failer},
}

configCallback := func() (*Config[*mocks.Runnable], error) {
return NewConfig("test", entries)
}

runner, err := NewRunner(configCallback)
require.NoError(t, err)

errCh := make(chan error, 1)
go func() {
errCh <- runner.Run(t.Context())
}()

time.Sleep(60 * time.Millisecond)
synctest.Wait()

select {
case err := <-errCh:
require.Error(t, err)
require.ErrorIs(t, err, ErrRunnableFailed)
require.ErrorContains(t, err, "boom")
default:
t.Fatal("Run should have returned error from failing child")
}

survivor.AssertExpectations(t)
failer.AssertExpectations(t)
})
}

func TestCompositeRunner_MultipleChildFailures(t *testing.T) {
t.Parallel()
synctest.Test(t, func(t *testing.T) {
Expand All @@ -510,20 +631,23 @@ func TestCompositeRunner_MultipleChildFailures(t *testing.T) {
started <- struct{}{}
time.Sleep(20 * time.Millisecond)
}).Return(failErr)
mockRunnable1.On("Stop").Once()

mockRunnable2 := mocks.NewMockRunnable()
mockRunnable2.On("String").Return("failer2")
mockRunnable2.On("Run", mock.Anything).Run(func(args mock.Arguments) {
started <- struct{}{}
time.Sleep(20 * time.Millisecond)
}).Return(failErr)
mockRunnable2.On("Stop").Once()

mockRunnable3 := mocks.NewMockRunnable()
mockRunnable3.On("String").Return("failer3")
mockRunnable3.On("Run", mock.Anything).Run(func(args mock.Arguments) {
started <- struct{}{}
time.Sleep(20 * time.Millisecond)
}).Return(failErr)
mockRunnable3.On("Stop").Once()

entries := []RunnableEntry[*mocks.Runnable]{
{Runnable: mockRunnable1},
Expand Down