From f22bace47c0c959c9d3e44d7ccf467e9a654e9e0 Mon Sep 17 00:00:00 2001 From: Robert Terhaar Date: Fri, 6 Feb 2026 12:53:54 -0500 Subject: [PATCH] move tests to synctest, remove sleeps --- runnables/composite/integration_race_test.go | 331 ++++++++----------- runnables/composite/runner_test.go | 106 +++--- runnables/httpserver/state_test.go | 1 - supervisor/reload_test.go | 39 +-- supervisor/shutdown_test.go | 170 +++------- supervisor/state_deduplication_test.go | 218 ++++++------ supervisor/state_monitoring_test.go | 122 +++---- 7 files changed, 402 insertions(+), 585 deletions(-) diff --git a/runnables/composite/integration_race_test.go b/runnables/composite/integration_race_test.go index a66705f..2716891 100644 --- a/runnables/composite/integration_race_test.go +++ b/runnables/composite/integration_race_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "testing" + "testing/synctest" "time" "github.com/robbyt/go-supervisor/runnables/mocks" @@ -20,7 +21,7 @@ func TestIntegration_CompositeNoRaceCondition(t *testing.T) { t.Skip("Skipping integration test in short mode") } - const iterations = 5 // Test multiple times to catch race conditions + const iterations = 5 for i := 0; i < iterations; i++ { t.Run(fmt.Sprintf("iteration_%d", i), func(t *testing.T) { @@ -31,113 +32,95 @@ func TestIntegration_CompositeNoRaceCondition(t *testing.T) { func testCompositeRaceCondition(t *testing.T) { t.Helper() - // Channel to signal when Run() methods are called - runCalled := make(chan struct{}, 3) - - // Create mock runnables using the mocks package - mock1 := mocks.NewMockRunnableWithStateable() - mock2 := mocks.NewMockRunnableWithStateable() - mock3 := mocks.NewMockRunnableWithStateable() - mock1.On("String").Return("service1") - mock1.On("Run", mock.Anything).Return(nil).Run(func(args mock.Arguments) { - runCalled <- struct{}{} // Signal that Run was called - ctx := args.Get(0).(context.Context) - <-ctx.Done() // Block until cancelled like a real service - }) - mock1.On("Stop").Return() - mock1.On("IsRunning").Return(true) - mock1.On("GetState").Return("Running") - - mock2.On("String").Return("service2") - mock2.On("Run", mock.Anything).Return(nil).Run(func(args mock.Arguments) { - runCalled <- struct{}{} // Signal that Run was called - ctx := args.Get(0).(context.Context) - <-ctx.Done() // Block until cancelled like a real service - }) - mock2.On("Stop").Return() - mock2.On("IsRunning").Return(true) - mock2.On("GetState").Return("Running") - - mock3.On("String").Return("service3") - mock3.On("Run", mock.Anything).Return(nil).Run(func(args mock.Arguments) { - runCalled <- struct{}{} // Signal that Run was called - ctx := args.Get(0).(context.Context) - <-ctx.Done() // Block until cancelled like a real service - }) - mock3.On("Stop").Return() - mock3.On("IsRunning").Return(true) - mock3.On("GetState").Return("Running") - - // Create config callback - callback := func() (*Config[*mocks.MockRunnableWithStateable], error) { - return NewConfig("test-composite", []RunnableEntry[*mocks.MockRunnableWithStateable]{ - {Runnable: mock1}, - {Runnable: mock2}, - {Runnable: mock3}, + synctest.Test(t, func(t *testing.T) { + runCalled := make(chan struct{}, 3) + + mock1 := mocks.NewMockRunnableWithStateable() + mock2 := mocks.NewMockRunnableWithStateable() + mock3 := mocks.NewMockRunnableWithStateable() + mock1.On("String").Return("service1") + mock1.On("Run", mock.Anything).Return(nil).Run(func(args mock.Arguments) { + runCalled <- struct{}{} + ctx := args.Get(0).(context.Context) + <-ctx.Done() }) - } + 
mock1.On("Stop").Return() + mock1.On("IsRunning").Return(true) + mock1.On("GetState").Return("Running") + + mock2.On("String").Return("service2") + mock2.On("Run", mock.Anything).Return(nil).Run(func(args mock.Arguments) { + runCalled <- struct{}{} + ctx := args.Get(0).(context.Context) + <-ctx.Done() + }) + mock2.On("Stop").Return() + mock2.On("IsRunning").Return(true) + mock2.On("GetState").Return("Running") + + mock3.On("String").Return("service3") + mock3.On("Run", mock.Anything).Return(nil).Run(func(args mock.Arguments) { + runCalled <- struct{}{} + ctx := args.Get(0).(context.Context) + <-ctx.Done() + }) + mock3.On("Stop").Return() + mock3.On("IsRunning").Return(true) + mock3.On("GetState").Return("Running") + + callback := func() (*Config[*mocks.MockRunnableWithStateable], error) { + return NewConfig("test-composite", []RunnableEntry[*mocks.MockRunnableWithStateable]{ + {Runnable: mock1}, + {Runnable: mock2}, + {Runnable: mock3}, + }) + } - // Create composite runner with real FSM - runner, err := NewRunner(callback) - require.NoError(t, err) + runner, err := NewRunner(callback) + require.NoError(t, err) - ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) - defer cancel() + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) + defer cancel() - // Start runner - runErr := make(chan error, 1) - go func() { - runErr <- runner.Run(ctx) - }() + runErr := make(chan error, 1) + go func() { + runErr <- runner.Run(ctx) + }() - // Wait for IsRunning() to return true - require.Eventually(t, func() bool { - return runner.IsRunning() - }, 5*time.Second, 50*time.Millisecond, "Composite should report as running") + // Advance virtual clock past the mock's default 1ms Run delay + time.Sleep(10 * time.Millisecond) + synctest.Wait() - // Wait for all Run() methods to be called using channel synchronization - for i := 0; i < 3; i++ { - select { - case <-runCalled: - // One Run() method was called - case <-time.After(5 * time.Second): - t.Fatalf("Timeout waiting for Run() calls, only received %d of 3", i) - } - } - - // CRITICAL TEST: When composite reports running, all children should be running - assert.True(t, mock1.IsRunning(), - "RACE CONDITION: Composite reports running but child 1 not running") - assert.True(t, mock2.IsRunning(), - "RACE CONDITION: Composite reports running but child 2 not running") - assert.True(t, mock3.IsRunning(), - "RACE CONDITION: Composite reports running but child 3 not running") + assert.True(t, runner.IsRunning(), "Composite should report as running") + assert.Len(t, runCalled, 3, "All 3 Run() methods should have been called") - // Test child states through composite - childStates := runner.GetChildStates() - assert.Len(t, childStates, 3, "Should have 3 child states") + assert.True(t, mock1.IsRunning(), + "RACE CONDITION: Composite reports running but child 1 not running") + assert.True(t, mock2.IsRunning(), + "RACE CONDITION: Composite reports running but child 2 not running") + assert.True(t, mock3.IsRunning(), + "RACE CONDITION: Composite reports running but child 3 not running") - for name, state := range childStates { - assert.Equal(t, "Running", state, "Child %s should be running", name) - } + childStates := runner.GetChildStates() + assert.Len(t, childStates, 3, "Should have 3 child states") + for name, state := range childStates { + assert.Equal(t, "Running", state, "Child %s should be running", name) + } - // Stop the runner - cancel() + cancel() + synctest.Wait() - // Wait for shutdown - timeoutCtx, timeoutCancel := 
context.WithTimeout(t.Context(), 5*time.Second) - defer timeoutCancel() - select { - case err := <-runErr: - require.NoError(t, err) - case <-timeoutCtx.Done(): - t.Fatal("Composite did not shutdown within timeout") - } + select { + case err := <-runErr: + require.NoError(t, err) + case <-time.After(5 * time.Second): + t.Fatal("Composite did not shutdown within timeout") + } - // Verify Stop() was called on all mocks - mock1.AssertCalled(t, "Stop") - mock2.AssertCalled(t, "Stop") - mock3.AssertCalled(t, "Stop") + mock1.AssertCalled(t, "Stop") + mock2.AssertCalled(t, "Stop") + mock3.AssertCalled(t, "Stop") + }) } // TestIntegration_CompositeFullLifecycle tests complete composite lifecycle @@ -145,101 +128,75 @@ func TestIntegration_CompositeFullLifecycle(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } + synctest.Test(t, func(t *testing.T) { + mock1 := mocks.NewMockRunnableWithStateable() + mock2 := mocks.NewMockRunnableWithStateable() + + mock1.On("String").Return("mock-service-1") + mock1.On("Run", mock.Anything).Return(nil).Run(func(args mock.Arguments) { + ctx := args.Get(0).(context.Context) + <-ctx.Done() + }) + mock1.On("Stop").Return() + mock1.On("GetState").Return("Running") - // Create mock runnables for testing - mock1 := mocks.NewMockRunnableWithStateable() - mock2 := mocks.NewMockRunnableWithStateable() - - // Set up mock expectations for normal operation - mock1.On("String").Return("mock-service-1") - mock1.On("Run", mock.Anything).Return(nil).Run(func(args mock.Arguments) { - ctx := args.Get(0).(context.Context) - <-ctx.Done() // Block until cancelled like a real service - }) - mock1.On("Stop").Return() - mock1.On("GetState").Return("Running") - - mock2.On("String").Return("mock-service-2") - mock2.On("Run", mock.Anything).Return(nil).Run(func(args mock.Arguments) { - ctx := args.Get(0).(context.Context) - <-ctx.Done() // Block until cancelled like a real service - }) - mock2.On("Stop").Return() - mock2.On("GetState").Return("Running") - - // Create config - callback := func() (*Config[*mocks.MockRunnableWithStateable], error) { - return NewConfig("integration-test", []RunnableEntry[*mocks.MockRunnableWithStateable]{ - {Runnable: mock1}, - {Runnable: mock2}, + mock2.On("String").Return("mock-service-2") + mock2.On("Run", mock.Anything).Return(nil).Run(func(args mock.Arguments) { + ctx := args.Get(0).(context.Context) + <-ctx.Done() }) - } + mock2.On("Stop").Return() + mock2.On("GetState").Return("Running") + + callback := func() (*Config[*mocks.MockRunnableWithStateable], error) { + return NewConfig("integration-test", []RunnableEntry[*mocks.MockRunnableWithStateable]{ + {Runnable: mock1}, + {Runnable: mock2}, + }) + } - // Create runner - runner, err := NewRunner(callback) - require.NoError(t, err) - - // Initial state - assert.Equal(t, "New", runner.GetState()) - assert.False(t, runner.IsRunning()) - - ctx, cancel := context.WithCancel(t.Context()) - - // Start runner - runErr := make(chan error, 1) - go func() { - runErr <- runner.Run(ctx) - }() - - // Should transition through states - assert.Eventually(t, func() bool { - state := runner.GetState() - return state == "Booting" || state == "Running" - }, 2*time.Second, 50*time.Millisecond, "Should transition to Booting") - - // Wait for Running state - assert.Eventually(t, func() bool { - return runner.IsRunning() && runner.GetState() == "Running" - }, 5*time.Second, 50*time.Millisecond, "Should transition to Running") - - // Verify all mocks received Run() calls - time.Sleep(100 * 
time.Millisecond) // Give time for calls to register - mock1.AssertCalled(t, "Run", mock.Anything) - mock2.AssertCalled(t, "Run", mock.Anything) - - // Test child states - childStates := runner.GetChildStates() - assert.Len(t, childStates, 2) - assert.Equal(t, "Running", childStates["mock-service-1"]) - assert.Equal(t, "Running", childStates["mock-service-2"]) - - // Stop runner - cancel() - - // Should transition to stopping/stopped - assert.Eventually(t, func() bool { - state := runner.GetState() - return state == "Stopping" || state == "Stopped" - }, 2*time.Second, 50*time.Millisecond, "Should transition to Stopping") - - // Wait for shutdown - timeoutCtx, timeoutCancel := context.WithTimeout(t.Context(), 5*time.Second) - defer timeoutCancel() - select { - case err := <-runErr: + runner, err := NewRunner(callback) require.NoError(t, err) - case <-timeoutCtx.Done(): - t.Fatal("Runner did not shutdown within timeout") - } - // Final state - assert.Eventually(t, func() bool { - return runner.GetState() == "Stopped" - }, 1*time.Second, 10*time.Millisecond, "Should be Stopped") + assert.Equal(t, "New", runner.GetState()) + assert.False(t, runner.IsRunning()) + + ctx, cancel := context.WithCancel(t.Context()) + + runErr := make(chan error, 1) + go func() { + runErr <- runner.Run(ctx) + }() - assert.False(t, runner.IsRunning()) + // Advance virtual clock past the mock's default 1ms Run delay + time.Sleep(10 * time.Millisecond) + synctest.Wait() - // Verify Stop() was called on all mocks - mock1.AssertCalled(t, "Stop") - mock2.AssertCalled(t, "Stop") + assert.True(t, runner.IsRunning(), "Should be Running") + assert.Equal(t, "Running", runner.GetState()) + + mock1.AssertCalled(t, "Run", mock.Anything) + mock2.AssertCalled(t, "Run", mock.Anything) + + childStates := runner.GetChildStates() + assert.Len(t, childStates, 2) + assert.Equal(t, "Running", childStates["mock-service-1"]) + assert.Equal(t, "Running", childStates["mock-service-2"]) + + cancel() + synctest.Wait() + + select { + case err := <-runErr: + require.NoError(t, err) + case <-time.After(5 * time.Second): + t.Fatal("Runner did not shutdown within timeout") + } + + assert.Equal(t, "Stopped", runner.GetState()) + assert.False(t, runner.IsRunning()) + + mock1.AssertCalled(t, "Stop") + mock2.AssertCalled(t, "Stop") + }) } diff --git a/runnables/composite/runner_test.go b/runnables/composite/runner_test.go index b1417e0..a3c630a 100644 --- a/runnables/composite/runner_test.go +++ b/runnables/composite/runner_test.go @@ -7,6 +7,7 @@ import ( "os" "sync/atomic" "testing" + "testing/synctest" "time" "github.com/robbyt/go-supervisor/internal/finitestate" @@ -531,73 +532,74 @@ func TestCompositeRunner_Stop(t *testing.T) { func TestCompositeRunner_MultipleChildFailures(t *testing.T) { t.Parallel() + synctest.Test(t, func(t *testing.T) { + failErr := errors.New("child failed") + started := make(chan struct{}) - failErr := errors.New("child failed") - started := make(chan struct{}) + mockRunnable1 := mocks.NewMockRunnable() + mockRunnable1.On("String").Return("failer1").Maybe() + mockRunnable1.On("Stop").Maybe() + mockRunnable1.On("Run", mock.Anything).Run(func(args mock.Arguments) { + started <- struct{}{} + time.Sleep(20 * time.Millisecond) + }).Return(failErr) - mockRunnable1 := mocks.NewMockRunnable() - mockRunnable1.On("String").Return("failer1").Maybe() - mockRunnable1.On("Stop").Maybe() - mockRunnable1.On("Run", mock.Anything).Run(func(args mock.Arguments) { - started <- struct{}{} - time.Sleep(20 * time.Millisecond) - // Return real 
error (goes through startRunnable's error path) - }).Return(failErr) + mockRunnable2 := mocks.NewMockRunnable() + mockRunnable2.On("String").Return("failer2").Maybe() + mockRunnable2.On("Stop").Maybe() + mockRunnable2.On("Run", mock.Anything).Run(func(args mock.Arguments) { + started <- struct{}{} + time.Sleep(20 * time.Millisecond) + }).Return(failErr) + + mockRunnable3 := mocks.NewMockRunnable() + mockRunnable3.On("String").Return("failer3").Maybe() + mockRunnable3.On("Stop").Maybe() + mockRunnable3.On("Run", mock.Anything).Run(func(args mock.Arguments) { + started <- struct{}{} + time.Sleep(20 * time.Millisecond) + }).Return(failErr) - mockRunnable2 := mocks.NewMockRunnable() - mockRunnable2.On("String").Return("failer2").Maybe() - mockRunnable2.On("Stop").Maybe() - mockRunnable2.On("Run", mock.Anything).Run(func(args mock.Arguments) { - started <- struct{}{} - time.Sleep(20 * time.Millisecond) - }).Return(failErr) - - mockRunnable3 := mocks.NewMockRunnable() - mockRunnable3.On("String").Return("failer3").Maybe() - mockRunnable3.On("Stop").Maybe() - mockRunnable3.On("Run", mock.Anything).Run(func(args mock.Arguments) { - started <- struct{}{} - time.Sleep(20 * time.Millisecond) - }).Return(failErr) + entries := []RunnableEntry[*mocks.Runnable]{ + {Runnable: mockRunnable1}, + {Runnable: mockRunnable2}, + {Runnable: mockRunnable3}, + } - entries := []RunnableEntry[*mocks.Runnable]{ - {Runnable: mockRunnable1}, - {Runnable: mockRunnable2}, - {Runnable: mockRunnable3}, - } + configCallback := func() (*Config[*mocks.Runnable], error) { + return NewConfig("test", entries) + } - configCallback := func() (*Config[*mocks.Runnable], error) { - return NewConfig("test", entries) - } + runner, err := NewRunner(configCallback) + require.NoError(t, err) - runner, err := NewRunner(configCallback) - require.NoError(t, err) + assert.Equal(t, 1, cap(runner.serverErrors), "initial capacity should be 1") - // Channel capacity should grow in boot() to match entry count - assert.Equal(t, 1, cap(runner.serverErrors), "initial capacity should be 1") + runErr := make(chan error, 1) + go func() { + runErr <- runner.Run(t.Context()) + }() - runErr := make(chan error, 1) - go func() { - runErr <- runner.Run(t.Context()) - }() + // Advance virtual clock past the mock's default 1ms Run delay + time.Sleep(10 * time.Millisecond) + synctest.Wait() - // Wait for all children to start - for range 3 { - <-started - } + for range 3 { + <-started + } + + assert.Equal(t, 3, cap(runner.serverErrors), "capacity should match entry count after boot") - // Verify channel was resized - assert.Equal(t, 3, cap(runner.serverErrors), "capacity should match entry count after boot") + // Advance virtual clock past 20ms sleep in callbacks so children return errors + time.Sleep(30 * time.Millisecond) + synctest.Wait() - // Run() should return with the first error - require.Eventually(t, func() bool { select { case err := <-runErr: require.Error(t, err) require.ErrorIs(t, err, ErrRunnableFailed) - return true default: - return false + t.Fatal("runner should return an error from failing children") } - }, 2*time.Second, 10*time.Millisecond, "runner should return an error from failing children") + }) } diff --git a/runnables/httpserver/state_test.go b/runnables/httpserver/state_test.go index 31334b9..b375b79 100644 --- a/runnables/httpserver/state_test.go +++ b/runnables/httpserver/state_test.go @@ -262,7 +262,6 @@ func TestGetStateChanWithTimeout(t *testing.T) { // Now test that the channel receives state changes go func() { - time.Sleep(50 * 
time.Millisecond) err := server.fsm.SetState(finitestate.StatusRunning) assert.NoError(t, err) }() diff --git a/supervisor/reload_test.go b/supervisor/reload_test.go index 7a620da..96aacbd 100644 --- a/supervisor/reload_test.go +++ b/supervisor/reload_test.go @@ -25,7 +25,7 @@ func TestPIDZero_ReloadManager(t *testing.T) { sender.On("GetReloadTrigger").Return(reloadTrigger) sender.On("Run", mock.Anything).Return(nil) - sender.On("Reload").Return() + sender.On("Reload").Return().Once() sender.On("Stop").Return() sender.On("GetState").Return("running").Maybe() sender.On("GetStateChan", mock.Anything).Return(stateChan).Maybe() @@ -43,12 +43,9 @@ func TestPIDZero_ReloadManager(t *testing.T) { // Trigger reload reloadTrigger <- struct{}{} - // Allow reload to process - time.Sleep(100 * time.Millisecond) - - // Verify reload was called once - sender.AssertCalled(t, "Reload") - sender.AssertNumberOfCalls(t, "Reload", 1) + require.Eventually(t, func() bool { + return !sender.IsMethodCallable(t, "Reload") + }, 1*time.Second, 10*time.Millisecond) p.Shutdown() <-done @@ -71,8 +68,8 @@ func TestPIDZero_ReloadManager(t *testing.T) { sender1.On("Run", mock.Anything).Return(nil) sender2.On("Run", mock.Anything).Return(nil) - sender1.On("Reload").Return() - sender2.On("Reload").Return() + sender1.On("Reload").Return().Times(2) + sender2.On("Reload").Return().Times(2) sender1.On("Stop").Return() sender2.On("Stop").Return() @@ -96,11 +93,9 @@ func TestPIDZero_ReloadManager(t *testing.T) { reloadTrigger1 <- struct{}{} reloadTrigger2 <- struct{}{} - time.Sleep(100 * time.Millisecond) - - // Expect each service to have Reload() called twice - sender1.AssertNumberOfCalls(t, "Reload", 2) - sender2.AssertNumberOfCalls(t, "Reload", 2) + require.Eventually(t, func() bool { + return !sender1.IsMethodCallable(t, "Reload") && !sender2.IsMethodCallable(t, "Reload") + }, 1*time.Second, 10*time.Millisecond) p.Shutdown() <-done @@ -177,18 +172,16 @@ func TestPIDZero_ReloadManager(t *testing.T) { execDone <- pid0.Run() }() - // Allow time for services to start - time.Sleep(10 * time.Millisecond) + require.Eventually(t, func() bool { + return pid0.ctx.Err() == nil + }, time.Second, 5*time.Millisecond) - // Manually trigger reload via API call pid0.ReloadAll() - // Allow time for reload to complete - time.Sleep(50 * time.Millisecond) - - // Verify both services were reloaded - mockService1.AssertNumberOfCalls(t, "Reload", 1) - mockService2.AssertNumberOfCalls(t, "Reload", 1) + require.Eventually(t, func() bool { + return !mockService1.IsMethodCallable(t, "Reload") && + !mockService2.IsMethodCallable(t, "Reload") + }, 1*time.Second, 10*time.Millisecond) // Shutdown and wait for completion pid0.Shutdown() diff --git a/supervisor/shutdown_test.go b/supervisor/shutdown_test.go index 90f4c1c..852dab3 100644 --- a/supervisor/shutdown_test.go +++ b/supervisor/shutdown_test.go @@ -3,6 +3,7 @@ package supervisor import ( "context" "testing" + "testing/synctest" "time" "github.com/robbyt/go-supervisor/runnables/mocks" @@ -13,48 +14,35 @@ import ( // TestPIDZero_StartShutdownManager_TriggersShutdown verifies that receiving a signal // on a ShutdownSender's trigger channel calls the supervisor's Shutdown method. +// Cannot use synctest: the listener goroutine calls Shutdown() which calls wg.Wait() +// on the same WaitGroup that includes startShutdownManager, creating a circular dependency +// that synctest's "all goroutines must complete" requirement turns into a deadlock. 
func TestPIDZero_StartShutdownManager_TriggersShutdown(t *testing.T) { t.Parallel() - // Create a context with cancel for cleanup and monitoring - // Use a background context for the supervisor initially supervisorCtx, supervisorCancel := context.WithCancel(context.Background()) - defer supervisorCancel() // Ensure cleanup + defer supervisorCancel() - // Create mock service that implements ShutdownSender mockService := mocks.NewMockRunnableWithShutdownSender() - shutdownChan := make(chan struct{}, 1) // Buffered to prevent blocking sender + shutdownChan := make(chan struct{}, 1) mockService.On("GetShutdownTrigger").Return(shutdownChan).Once() mockService.On("String").Return("mockShutdownService").Maybe() mockService.On("Stop").Return().Maybe() - // Create a supervisor with the mock runnable - // Pass the specific context so we can monitor its cancellation pidZero, err := New(WithContext(supervisorCtx), WithRunnables(mockService)) require.NoError(t, err) - // Start the shutdown manager using wg.Go pidZero.wg.Go(pidZero.startShutdownManager) - time.Sleep(200 * time.Millisecond) - - // Send a shutdown signal from the mock runnable shutdownChan <- struct{}{} - // Wait for the supervisor's context to be cancelled, which indicates - // that p.Shutdown() was called by the manager. select { case <-pidZero.ctx.Done(): - // Context was cancelled as expected, Shutdown() was called. case <-time.After(2 * time.Second): t.Fatal("Supervisor context was not cancelled, Shutdown() likely not called") } - // Bypass waiting for the WaitGroup - we've already confirmed - // the Shutdown() was triggered which is the key test assertion - - // Verify expectations on the mock mockService.AssertExpectations(t) } @@ -62,136 +50,92 @@ func TestPIDZero_StartShutdownManager_TriggersShutdown(t *testing.T) { // cleans up its listener goroutines when the main context is cancelled. 
func TestPIDZero_StartShutdownManager_ContextCancel(t *testing.T) { t.Parallel() + synctest.Test(t, func(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() - // Create a context with cancel for cleanup - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() // Ensure cleanup - - // Create mock service that implements ShutdownSender - mockService := mocks.NewMockRunnableWithShutdownSender() - shutdownChan := make(chan struct{}) // Unbuffered is fine, won't be used - - // Expect GetShutdownTrigger to be called, but no signal sent - mockService.On("GetShutdownTrigger").Return(shutdownChan).Once() - mockService.On("String").Return("mockShutdownService").Maybe() - mockService.On("Stop").Return().Maybe() + mockService := mocks.NewMockRunnableWithShutdownSender() + shutdownChan := make(chan struct{}) - // Create a supervisor with the mock runnable - pidZero, err := New(WithContext(ctx), WithRunnables(mockService)) - require.NoError(t, err) + mockService.On("GetShutdownTrigger").Return(shutdownChan).Once() + mockService.On("String").Return("mockShutdownService").Maybe() + mockService.On("Stop").Return().Maybe() - // Start the shutdown manager using wg.Go - pidZero.wg.Go(pidZero.startShutdownManager) + pidZero, err := New(WithContext(ctx), WithRunnables(mockService)) + require.NoError(t, err) - // Give the manager a moment to start its internal listener goroutine - time.Sleep(100 * time.Millisecond) // Increased sleep duration + pidZero.wg.Go(pidZero.startShutdownManager) - // Cancel the main context - cancel() + cancel() + synctest.Wait() - // Wait for the startShutdownManager goroutine (and its listeners) to finish - waitChan := make(chan struct{}) - go func() { - pidZero.wg.Wait() - close(waitChan) - }() - - select { - case <-waitChan: - // WaitGroup finished cleanly - case <-time.After(1 * time.Second): - t.Fatal("WaitGroup did not finish within the timeout after context cancellation") - } - - // Verify expectations on the mock - mockService.AssertExpectations(t) + mockService.AssertExpectations(t) + }) } // TestPIDZero_StartShutdownManager_NoSenders verifies that the manager // starts and stops cleanly even if no runnables implement ShutdownSender. 
func TestPIDZero_StartShutdownManager_NoSenders(t *testing.T) { t.Parallel() + synctest.Test(t, func(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() - // Create a context with cancel for cleanup - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() // Ensure cleanup - - // Create a mock service that does *not* implement ShutdownSender - nonSenderRunnable := mocks.NewMockRunnable() - nonSenderRunnable.On("Run", mock.Anything).Return(nil).Maybe() - nonSenderRunnable.On("Stop").Maybe() - nonSenderRunnable.On("String").Return("simpleRunnable").Maybe() + nonSenderRunnable := mocks.NewMockRunnable() + nonSenderRunnable.On("Run", mock.Anything).Return(nil).Maybe() + nonSenderRunnable.On("Stop").Maybe() + nonSenderRunnable.On("String").Return("simpleRunnable").Maybe() - // Create a supervisor with the non-sender runnable - pidZero, err := New(WithContext(ctx), WithRunnables(nonSenderRunnable)) - require.NoError(t, err) - - // Start the shutdown manager using wg.Go - pidZero.wg.Go(pidZero.startShutdownManager) - - // Give the manager a moment to start - time.Sleep(50 * time.Millisecond) - - // Cancel the main context - cancel() + pidZero, err := New(WithContext(ctx), WithRunnables(nonSenderRunnable)) + require.NoError(t, err) - // Wait for the startShutdownManager goroutine to finish - waitChan := make(chan struct{}) - go func() { - pidZero.wg.Wait() - close(waitChan) - }() + pidZero.wg.Go(pidZero.startShutdownManager) - select { - case <-waitChan: - // WaitGroup finished cleanly - case <-time.After(1 * time.Second): - t.Fatal("WaitGroup did not finish within the timeout when no senders were present") - } + cancel() + synctest.Wait() - // No ShutdownSender specific mock expectations to verify - nonSenderRunnable.AssertExpectations(t) // Assert expectations for Run/Stop/String if set + nonSenderRunnable.AssertExpectations(t) + }) } // TestPIDZero_Shutdown_WithTimeoutNotExceeded verifies that shutdown completes // successfully when runnables finish within the configured timeout. +// Cannot use synctest: calls pidZero.Run() which uses signal.Notify. 
func TestPIDZero_Shutdown_WithTimeoutNotExceeded(t *testing.T) { t.Parallel() - // Create a blocking channel that will be closed when Stop is called stopCalled := make(chan struct{}) - // Create a mock service that blocks in Run until Stop is called runnable := mocks.NewMockRunnable() + runStarted := make(chan struct{}) - // Configure Run to block until Stop is called runnable.On("Run", mock.Anything).Return(nil).Run(func(args mock.Arguments) { - // Block until stopCalled is closed + close(runStarted) <-stopCalled }) - // Configure Stop to unblock the Run method runnable.On("Stop").Once().Run(func(args mock.Arguments) { close(stopCalled) }) runnable.On("String").Return("blockingRunnable").Maybe() - // Create supervisor with reasonable timeout pidZero, err := New( WithRunnables(runnable), WithShutdownTimeout(2*time.Second), ) require.NoError(t, err) - // Run the supervisor execDone := make(chan error, 1) go func() { execDone <- pidZero.Run() }() - // Let Run method start - time.Sleep(200 * time.Millisecond) + select { + case <-runStarted: + case <-time.After(time.Second): + t.Fatal("Run did not start in time") + } shutdownStart := time.Now() pidZero.Shutdown() @@ -212,52 +156,46 @@ func TestPIDZero_Shutdown_WithTimeoutNotExceeded(t *testing.T) { // TestPIDZero_Shutdown_WithTimeoutExceeded verifies that shutdown still completes // but logs a warning when the timeout is exceeded by goroutines that don't stop timely. +// Cannot use synctest: calls pidZero.Run() which uses signal.Notify. func TestPIDZero_Shutdown_WithTimeoutExceeded(t *testing.T) { t.Parallel() - // Create a blocking channel that will NOT be closed by Stop - // to simulate a runnable that doesn't terminate quickly stopCalled := make(chan struct{}) shutdownComplete := make(chan struct{}) - // Create a mock service that blocks in Run indefinitely runnable := mocks.NewMockRunnable() - // Configure Run to block indefinitely + runStarted := make(chan struct{}) runnable.On("Run", mock.Anything).Return(nil).Run(func(args mock.Arguments) { - // Block until stopCalled is closed (which won't happen in this test) + close(runStarted) select { case <-stopCalled: - // This won't happen case <-shutdownComplete: - // This will happen after shutdown completes - // We need this to prevent the goroutine from leaking } }) - // Configure Stop to NOT unblock the Run method - // This simulates a runnable that's slow to finish runnable.On("Stop").Once().Run(func(args mock.Arguments) { - // Do not close stopCalled channel, simulating a stuck runnable - time.Sleep(50 * time.Millisecond) // Fast return from Stop itself + time.Sleep(50 * time.Millisecond) }) runnable.On("String").Return("stuckRunnable").Maybe() - // Create supervisor with very short timeout pidZero, err := New( WithRunnables(runnable), - WithShutdownTimeout(200*time.Millisecond), // Shorter than our test duration + WithShutdownTimeout(200*time.Millisecond), ) require.NoError(t, err) - // Run the supervisor execDone := make(chan error, 1) go func() { execDone <- pidZero.Run() }() - time.Sleep(200 * time.Millisecond) + select { + case <-runStarted: + case <-time.After(time.Second): + t.Fatal("Run did not start in time") + } shutdownStart := time.Now() shutdownDone := make(chan struct{}) @@ -278,6 +216,6 @@ func TestPIDZero_Shutdown_WithTimeoutExceeded(t *testing.T) { t.Fatal("Shutdown did not complete despite timeout") } - close(shutdownComplete) // Prevent goroutine leak + close(shutdownComplete) runnable.AssertExpectations(t) } diff --git a/supervisor/state_deduplication_test.go 
b/supervisor/state_deduplication_test.go index 54443ee..2f59e8e 100644 --- a/supervisor/state_deduplication_test.go +++ b/supervisor/state_deduplication_test.go @@ -2,9 +2,8 @@ package supervisor import ( "context" - "maps" "testing" - "time" + "testing/synctest" "github.com/robbyt/go-supervisor/runnables/mocks" "github.com/stretchr/testify/assert" @@ -17,131 +16,98 @@ import ( // multiple times in a row through the state channel. func TestStateDeduplication(t *testing.T) { t.Parallel() - - // Create a context with a suitable timeout - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - // Create a channel for sending state updates - stateChan := make(chan string, 10) - runnable := mocks.NewMockRunnableWithStateable() - runnable.On("String").Return("test-runnable") - runnable.On("GetStateChan", mock.Anything).Return(stateChan) - runnable.On("GetState").Return("initial") - - // Create a new supervisor with our test runnable - pidZero, err := New(WithContext(ctx), WithRunnables(runnable)) - require.NoError(t, err) - - // Track the broadcasts that occur - broadcasts := []StateMap{} - broadcastChan := make(chan StateMap, 10) - unsubscribe := pidZero.AddStateSubscriber(broadcastChan) - defer unsubscribe() - - // Collect broadcasts in a background goroutine - collectDone := make(chan struct{}) - go func() { - defer close(collectDone) - for { - select { - case stateMap, ok := <-broadcastChan: - if !ok { + synctest.Test(t, func(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + stateChan := make(chan string, 10) + runnable := mocks.NewMockRunnableWithStateable() + runnable.On("String").Return("test-runnable") + runnable.On("GetStateChan", mock.Anything).Return(stateChan) + runnable.On("GetState").Return("initial") + + pidZero, err := New(WithContext(ctx), WithRunnables(runnable)) + require.NoError(t, err) + + broadcastChan := make(chan StateMap, 10) + unsubscribe := pidZero.AddStateSubscriber(broadcastChan) + defer unsubscribe() + + statesReceived := make(map[string]int) + collectDone := make(chan struct{}) + go func() { + defer close(collectDone) + for { + select { + case stateMap, ok := <-broadcastChan: + if !ok { + return + } + if state, ok := stateMap[runnable.String()]; ok { + statesReceived[state]++ + } + t.Logf("Received broadcast: %+v", stateMap) + case <-ctx.Done(): return } - // Copy the map to avoid issues with concurrent modification - copy := make(StateMap) - maps.Copy(copy, stateMap) - broadcasts = append(broadcasts, copy) - t.Logf("Received broadcast: %+v", copy) - case <-ctx.Done(): - return } - } - }() - - // Store the initial state to match production behavior - pidZero.stateMap.Store(runnable, "initial") - - // Start the state monitor - pidZero.wg.Add(1) - go pidZero.startStateMonitor() - - // Send the initial state to be discarded as per implementation - t.Log("Sending 'initial' to be discarded") - stateChan <- "initial" - time.Sleep(50 * time.Millisecond) - - // Test sequence: - // 1. Send "running" once - should trigger broadcast - // 2. Send "running" twice more - should be ignored as duplicates - // 3. Send "stopped" - should trigger broadcast - // 4. Send "stopped" again - should be ignored as duplicate - // 5. 
Send "error" - should trigger broadcast - - // First state change - t.Log("Sending 'running' state") - runnable.On("GetState").Return("running") - stateChan <- "running" - - // Send duplicate states - should be ignored - t.Log("Sending 'running' state again (should be ignored)") - stateChan <- "running" - - t.Log("Sending 'running' state a third time (should be ignored)") - stateChan <- "running" - - // Second state change - t.Log("Sending 'stopped' state") - runnable.On("GetState").Return("stopped") - stateChan <- "stopped" - - // Another duplicate - should be ignored - t.Log("Sending 'stopped' state again (should be ignored)") - stateChan <- "stopped" - - // Third state change - t.Log("Sending 'error' state") - runnable.On("GetState").Return("error") - stateChan <- "error" - time.Sleep(100 * time.Millisecond) - - // Clean up and wait for collection to complete - cancel() - unsubscribe() - close(broadcastChan) - <-collectDone - - // Log final state for debugging - t.Log("All broadcasts received:") - for i, b := range broadcasts { - t.Logf(" %d: %+v", i, b) - } - - // Count number of each state broadcast received - statesReceived := make(map[string]int) - for _, broadcast := range broadcasts { - // Look for the state of our test runnable - if state, ok := broadcast[runnable.String()]; ok { - statesReceived[state]++ - } - } - - // Log state counts - t.Logf("State broadcast counts: %+v", statesReceived) - - // We should have unique state broadcasts (one each) - // for running, stopped, and error states - assert.Equal( - t, 1, statesReceived["running"], - "Should receive exactly one 'running' state broadcast", - ) - assert.Equal( - t, 1, statesReceived["stopped"], - "Should receive exactly one 'stopped' state broadcast", - ) - assert.Equal( - t, 1, statesReceived["error"], - "Should receive exactly one 'error' state broadcast", - ) + }() + + pidZero.stateMap.Store(runnable, "initial") + pidZero.wg.Go(pidZero.startStateMonitor) + + // Test sequence: + // 1. Send "initial" - should be discarded (already captured in startRunnable) + // 2. Send "running" once - should trigger broadcast + // 3. Send "running" twice more - should be ignored as duplicates + // 4. Send "stopped" - should trigger broadcast + // 5. Send "stopped" again - should be ignored as duplicate + // 6. 
Send "error" - should trigger broadcast + + t.Log("Sending 'initial' to be discarded") + stateChan <- "initial" + + t.Log("Sending 'running' state") + runnable.On("GetState").Return("running") + stateChan <- "running" + + t.Log("Sending 'running' state again (should be ignored)") + stateChan <- "running" + + t.Log("Sending 'running' state a third time (should be ignored)") + stateChan <- "running" + + t.Log("Sending 'stopped' state") + runnable.On("GetState").Return("stopped") + stateChan <- "stopped" + + t.Log("Sending 'stopped' state again (should be ignored)") + stateChan <- "stopped" + + t.Log("Sending 'error' state") + runnable.On("GetState").Return("error") + stateChan <- "error" + + synctest.Wait() + + cancel() + unsubscribe() + close(broadcastChan) + <-collectDone + + t.Logf("State broadcast counts: %+v", statesReceived) + + assert.Equal( + t, 1, statesReceived["running"], + "Should receive exactly one 'running' state broadcast", + ) + assert.Equal( + t, 1, statesReceived["stopped"], + "Should receive exactly one 'stopped' state broadcast", + ) + assert.Equal( + t, 1, statesReceived["error"], + "Should receive exactly one 'error' state broadcast", + ) + }) } diff --git a/supervisor/state_monitoring_test.go b/supervisor/state_monitoring_test.go index f11e969..d1b1045 100644 --- a/supervisor/state_monitoring_test.go +++ b/supervisor/state_monitoring_test.go @@ -3,6 +3,7 @@ package supervisor import ( "context" "testing" + "testing/synctest" "time" "github.com/robbyt/go-supervisor/runnables/mocks" @@ -12,28 +13,25 @@ import ( ) // TestPIDZero_StartStateMonitor tests that the state monitor is started for stateable runnables. +// This test calls pid0.Run() which uses signal.Notify, so it cannot use synctest. func TestPIDZero_StartStateMonitor(t *testing.T) { t.Parallel() - // Create a mock stateable runnable mockStateable := mocks.NewMockRunnableWithStateable() mockStateable.On("String").Return("stateable-runnable").Maybe() mockStateable.On("Run", mock.Anything).Return(nil) mockStateable.On("Stop").Once() - mockStateable.On("GetState").Return("initial").Once() // Initial state - mockStateable.On("GetState").Return("running").Maybe() // Called during shutdown + mockStateable.On("GetState").Return("initial").Once() + mockStateable.On("GetState").Return("running").Maybe() - stateChan := make(chan string, 5) // Buffered to prevent blocking + stateChan := make(chan string, 5) mockStateable.On("GetStateChan", mock.Anything).Return(stateChan).Once() - // Will be called during startup verification mockStateable.On("IsRunning").Return(true).Once() - // Create context with timeout to ensure test completion ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - // Create supervisor with the mock runnable pid0, err := New( WithContext(ctx), WithRunnables(mockStateable), @@ -41,41 +39,30 @@ func TestPIDZero_StartStateMonitor(t *testing.T) { ) require.NoError(t, err) - // Create a state subscriber to verify state broadcasts stateUpdates := make(chan StateMap, 5) unsubscribe := pid0.AddStateSubscriber(stateUpdates) defer unsubscribe() - // Start the supervisor in a goroutine execDone := make(chan error, 1) go func() { execDone <- pid0.Run() }() - // Allow time for initialization - time.Sleep(50 * time.Millisecond) + stateChan <- "initial" + stateChan <- "running" + stateChan <- "stopping" - // Send state updates through the channel - stateChan <- "initial" // This should be discarded as it's the initial state - stateChan <- "running" // This will be processed - 
stateChan <- "stopping" // Additional state change - - // Use require.Eventually to verify the state monitor receives and broadcasts states require.Eventually(t, func() bool { - // Check if we have received at least one state update select { case stateMap := <-stateUpdates: - // We don't check for specific values, just that broadcasts are happening return stateMap["stateable-runnable"] != "" default: return false } }, 500*time.Millisecond, 50*time.Millisecond, "No state updates received") - // Cancel the context to shut down the supervisor cancel() - // Verify the supervisor shuts down cleanly select { case err := <-execDone: require.NoError(t, err) @@ -83,82 +70,57 @@ func TestPIDZero_StartStateMonitor(t *testing.T) { t.Fatal("Supervisor did not shut down in time") } - // Verify expectations mockStateable.AssertExpectations(t) } // TestPIDZero_SubscribeStateChanges tests the SubscribeStateChanges functionality. func TestPIDZero_SubscribeStateChanges(t *testing.T) { t.Parallel() + synctest.Test(t, func(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + mockService := mocks.NewMockRunnableWithStateable() + stateChan := make(chan string, 2) + mockService.On("GetStateChan", mock.Anything).Return(stateChan).Once() + mockService.On("String").Return("mock-service").Maybe() + mockService.On("Run", mock.Anything).Return(nil).Maybe() + mockService.On("Stop").Maybe() + mockService.On("GetState").Return("initial").Maybe() + mockService.On("IsRunning").Return(true).Maybe() + + pid0, err := New(WithContext(ctx), WithRunnables(mockService)) + require.NoError(t, err) - // Create a context with cancel for cleanup - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // Create mock services that implement Stateable - mockService := mocks.NewMockRunnableWithStateable() - stateChan := make(chan string, 2) - mockService.On("GetStateChan", mock.Anything).Return(stateChan).Once() - mockService.On("String").Return("mock-service").Maybe() - mockService.On("Run", mock.Anything).Return(nil).Maybe() - mockService.On("Stop").Maybe() - mockService.On("GetState").Return("initial").Maybe() - mockService.On("IsRunning").Return(true).Maybe() - - // Create a supervisor with the mock runnable - pid0, err := New(WithContext(ctx), WithRunnables(mockService)) - require.NoError(t, err) - - // Store initial states manually in stateMap - pid0.stateMap.Store(mockService, "initial") - - // Subscribe to state changes - subCtx, subCancel := context.WithCancel(context.Background()) - defer subCancel() - stateMapChan := pid0.SubscribeStateChanges(subCtx) - - // Manually call startStateMonitor to avoid the full Run sequence - pid0.wg.Add(1) - go pid0.startStateMonitor() + pid0.stateMap.Store(mockService, "initial") - // Give state monitor a moment to start - time.Sleep(100 * time.Millisecond) + subCtx, subCancel := context.WithCancel(t.Context()) + defer subCancel() + stateMapChan := pid0.SubscribeStateChanges(subCtx) - // Send an initial state update - stateChan <- "initial" // Should be discarded + pid0.wg.Go(pid0.startStateMonitor) + synctest.Wait() - // Send another state update - stateChan <- "running" // Should be broadcast - time.Sleep(100 * time.Millisecond) + stateChan <- "initial" + stateChan <- "running" - // Manually update state and trigger broadcast to ensure it happens - pid0.stateMap.Store(mockService, "running") - pid0.broadcastState() - time.Sleep(100 * time.Millisecond) + pid0.stateMap.Store(mockService, "running") + pid0.broadcastState() - // Verify we 
receive state updates - var stateMap StateMap - var foundRunning bool - timeout := time.After(500 * time.Millisecond) + synctest.Wait() - // Loop until we find the update we want or time out - for !foundRunning { - select { - case stateMap = <-stateMapChan: + var foundRunning bool + for range len(stateMapChan) { + stateMap := <-stateMapChan if val, ok := stateMap["mock-service"]; ok && val == "running" { foundRunning = true + break } - case <-timeout: - t.Fatal("Did not receive running state update in time") } - } - - assert.True(t, foundRunning, "Should have received a state map with running state") + assert.True(t, foundRunning, "Should have received a state map with running state") - // Cancel the context to clean up goroutines - cancel() - time.Sleep(50 * time.Millisecond) + cancel() - // Verify expectations - mockService.AssertExpectations(t) + mockService.AssertExpectations(t) + }) }
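
The tests in this patch share one structure: the test body runs inside synctest.Test, real sleeps are replaced by time.Sleep against the bubble's fake clock, and synctest.Wait lets every goroutine become durably blocked before asserting. Below is a minimal, self-contained sketch of that pattern; startWorker and the test are illustrative only and are not part of this repository, and it assumes the testing/synctest package as stabilized in Go 1.25.

package example

import (
	"context"
	"testing"
	"testing/synctest"
	"time"
)

// startWorker is a hypothetical stand-in for a runnable: it simulates a 1ms
// startup delay, signals readiness, then blocks until the context is cancelled.
func startWorker(ctx context.Context, ready chan<- struct{}) {
	time.Sleep(1 * time.Millisecond) // advances instantly on the bubble's fake clock
	close(ready)
	<-ctx.Done()
}

func TestWorkerStartsWithoutRealSleeps(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		ctx, cancel := context.WithCancel(t.Context())
		ready := make(chan struct{})

		go startWorker(ctx, ready)

		// Advance the virtual clock past the startup delay, then wait for the
		// worker to become durably blocked on <-ctx.Done().
		time.Sleep(10 * time.Millisecond)
		synctest.Wait()

		select {
		case <-ready:
		default:
			t.Fatal("worker did not become ready")
		}

		cancel()
		synctest.Wait() // let the worker observe cancellation and return
	})
}

Because the clock inside the bubble only advances while every goroutine is durably blocked, the 10ms sleep completes immediately in real time, which is why the patch can replace real sleeps and polling loops without adding flakiness.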