diff --git a/runnables/composite/runner.go b/runnables/composite/runner.go index c759c8c..4ca1c87 100644 --- a/runnables/composite/runner.go +++ b/runnables/composite/runner.go @@ -111,7 +111,7 @@ func (r *Runner[T]) Run(ctx context.Context) error { } // Start all child runnables - if err := r.boot(ctx); err != nil { + if err := r.boot(runCtx); err != nil { r.setStateError() return fmt.Errorf("failed to start child runnables: %w", err) } @@ -128,7 +128,8 @@ func (r *Runner[T]) Run(ctx context.Context) error { r.logger.Debug("Local context canceled") case err := <-r.serverErrors: r.setStateError() - return fmt.Errorf("%w: %w", ErrRunnableFailed, err) + stopErr := r.stopAllRunnables() + return fmt.Errorf("%w: %w", ErrRunnableFailed, errors.Join(err, stopErr)) } if err := r.fsm.TransitionIfCurrentState(finitestate.StatusRunning, finitestate.StatusStopping); err != nil { diff --git a/runnables/composite/runner_test.go b/runnables/composite/runner_test.go index d9c08f1..edb26a6 100644 --- a/runnables/composite/runner_test.go +++ b/runnables/composite/runner_test.go @@ -204,6 +204,7 @@ func TestCompositeRunner_Run(t *testing.T) { // Wait until context is cancelled <-args.Get(0).(context.Context).Done() }).Return(nil) + mockRunnable1.On("Stop").Once() // Create entries entries := []RunnableEntry[*mocks.Runnable]{ @@ -235,6 +236,7 @@ func TestCompositeRunner_Run(t *testing.T) { mockRunnable := mocks.NewMockRunnable() mockRunnable.On("String").Return("runnable1") mockRunnable.On("Run", mock.Anything).Return(errors.New("failed to start")) + mockRunnable.On("Stop").Once() entries := []RunnableEntry[*mocks.Runnable]{ {Runnable: mockRunnable, Config: nil}, @@ -498,6 +500,125 @@ func TestCompositeRunner_Stop(t *testing.T) { }) } +func TestCompositeRunner_StopCancelsChildContexts(t *testing.T) { + t.Parallel() + synctest.Test(t, func(t *testing.T) { + childDone := make(chan struct{}, 2) + + mock1 := mocks.NewMockRunnable() + mock1.On("String").Return("r1") + mock1.On("Run", mock.Anything).Run(func(args mock.Arguments) { + <-args.Get(0).(context.Context).Done() + childDone <- struct{}{} + }).Return(nil) + mock1.On("Stop").Once() + + mock2 := mocks.NewMockRunnable() + mock2.On("String").Return("r2") + mock2.On("Run", mock.Anything).Run(func(args mock.Arguments) { + <-args.Get(0).(context.Context).Done() + childDone <- struct{}{} + }).Return(nil) + mock2.On("Stop").Once() + + entries := []RunnableEntry[*mocks.Runnable]{ + {Runnable: mock1}, + {Runnable: mock2}, + } + + configCallback := func() (*Config[*mocks.Runnable], error) { + return NewConfig("test", entries) + } + + runner, err := NewRunner(configCallback) + require.NoError(t, err) + + errCh := make(chan error, 1) + go func() { + errCh <- runner.Run(t.Context()) + }() + + time.Sleep(10 * time.Millisecond) + synctest.Wait() + + assert.Equal(t, finitestate.StatusRunning, runner.GetState()) + + runner.Stop() + + time.Sleep(10 * time.Millisecond) + synctest.Wait() + + select { + case err := <-errCh: + require.NoError(t, err) + default: + t.Fatal("Run did not return after Stop — children may not derive from runCtx") + } + + for range 2 { + select { + case <-childDone: + default: + t.Fatal("child goroutine did not exit after Stop") + } + } + + mock1.AssertExpectations(t) + mock2.AssertExpectations(t) + }) +} + +func TestCompositeRunner_ErrorPathStopsSiblings(t *testing.T) { + t.Parallel() + synctest.Test(t, func(t *testing.T) { + survivor := mocks.NewMockRunnable() + survivor.On("String").Return("survivor") + survivor.On("Run", mock.Anything).Run(func(args mock.Arguments) { + <-args.Get(0).(context.Context).Done() + }).Return(nil) + survivor.On("Stop").Once() + + failer := mocks.NewMockRunnable() + failer.On("String").Return("failer") + failer.On("Run", mock.Anything).Run(func(args mock.Arguments) { + time.Sleep(50 * time.Millisecond) + }).Return(errors.New("boom")) + failer.On("Stop").Once() + + entries := []RunnableEntry[*mocks.Runnable]{ + {Runnable: survivor}, + {Runnable: failer}, + } + + configCallback := func() (*Config[*mocks.Runnable], error) { + return NewConfig("test", entries) + } + + runner, err := NewRunner(configCallback) + require.NoError(t, err) + + errCh := make(chan error, 1) + go func() { + errCh <- runner.Run(t.Context()) + }() + + time.Sleep(60 * time.Millisecond) + synctest.Wait() + + select { + case err := <-errCh: + require.Error(t, err) + require.ErrorIs(t, err, ErrRunnableFailed) + require.ErrorContains(t, err, "boom") + default: + t.Fatal("Run should have returned error from failing child") + } + + survivor.AssertExpectations(t) + failer.AssertExpectations(t) + }) +} + func TestCompositeRunner_MultipleChildFailures(t *testing.T) { t.Parallel() synctest.Test(t, func(t *testing.T) { @@ -510,6 +631,7 @@ func TestCompositeRunner_MultipleChildFailures(t *testing.T) { started <- struct{}{} time.Sleep(20 * time.Millisecond) }).Return(failErr) + mockRunnable1.On("Stop").Once() mockRunnable2 := mocks.NewMockRunnable() mockRunnable2.On("String").Return("failer2") @@ -517,6 +639,7 @@ func TestCompositeRunner_MultipleChildFailures(t *testing.T) { started <- struct{}{} time.Sleep(20 * time.Millisecond) }).Return(failErr) + mockRunnable2.On("Stop").Once() mockRunnable3 := mocks.NewMockRunnable() mockRunnable3.On("String").Return("failer3") @@ -524,6 +647,7 @@ func TestCompositeRunner_MultipleChildFailures(t *testing.T) { started <- struct{}{} time.Sleep(20 * time.Millisecond) }).Return(failErr) + mockRunnable3.On("Stop").Once() entries := []RunnableEntry[*mocks.Runnable]{ {Runnable: mockRunnable1},