diff --git a/ee/desktop/runner/runner.go b/ee/desktop/runner/runner.go index dc961cc29..8d79be12d 100644 --- a/ee/desktop/runner/runner.go +++ b/ee/desktop/runner/runner.go @@ -107,6 +107,7 @@ type DesktopUsersProcessesRunner struct { // menuRefreshInterval is the interval on which the desktop menu will be refreshed menuRefreshInterval time.Duration interrupt chan struct{} + interrupted bool // uidProcs is a map of uid to desktop process uidProcs map[string]processRecord // procsWg is a WaitGroup to wait for all desktop processes to finish during an interrupt @@ -222,6 +223,13 @@ func (r *DesktopUsersProcessesRunner) Execute() error { // Interrupt stops creating launcher desktop processes and kills any existing ones. // It also signals the execute loop to exit, so new desktop processes cease to spawn. func (r *DesktopUsersProcessesRunner) Interrupt(_ error) { + // Only perform shutdown tasks on first call to interrupt -- no need to repeat on potential extra calls. + if r.interrupted { + return + } + + r.interrupted = true + // Tell the execute loop to stop checking, and exit r.interrupt <- struct{}{} diff --git a/ee/desktop/runner/runner_test.go b/ee/desktop/runner/runner_test.go index 36b1f54e6..2b934cb04 100644 --- a/ee/desktop/runner/runner_test.go +++ b/ee/desktop/runner/runner_test.go @@ -145,7 +145,7 @@ func TestDesktopUserProcessRunner_Execute(t *testing.T) { assert.NoError(t, r.Execute()) }() - // let is run a few interval + // let it run a few intervals time.Sleep(r.updateInterval * 3) r.Interrupt(nil) @@ -185,6 +185,34 @@ func TestDesktopUserProcessRunner_Execute(t *testing.T) { p.Process.Wait() } }) + + // Confirm we can call Interrupt multiple times without blocking + interruptComplete := make(chan struct{}) + expectedInterrupts := 3 + for i := 0; i < expectedInterrupts; i += 1 { + go func() { + r.Interrupt(nil) + interruptComplete <- struct{}{} + }() + } + + receivedInterrupts := 0 + for { + if receivedInterrupts >= expectedInterrupts { + break + } + + select { + case <-interruptComplete: + receivedInterrupts += 1 + continue + case <-time.After(5 * time.Second): + t.Errorf("could not call interrupt multiple times and return within 5 seconds -- received %d interrupts before timeout", receivedInterrupts) + t.FailNow() + } + } + + require.Equal(t, expectedInterrupts, receivedInterrupts) }) } }