diff --git a/ee/desktop/runner/runner.go b/ee/desktop/runner/runner.go index 55d33e789..47e8d1db6 100644 --- a/ee/desktop/runner/runner.go +++ b/ee/desktop/runner/runner.go @@ -31,6 +31,7 @@ import ( "github.com/kolide/launcher/ee/desktop/user/notify" "github.com/kolide/launcher/ee/ui/assets" "github.com/kolide/launcher/pkg/backoff" + "github.com/kolide/launcher/pkg/rungroup" "github.com/kolide/launcher/pkg/traces" "github.com/shirou/gopsutil/v3/process" "golang.org/x/exp/maps" @@ -159,7 +160,7 @@ func New(k types.Knapsack, messenger runnerserver.Messenger, opts ...desktopUser updateInterval: k.DesktopUpdateInterval(), menuRefreshInterval: k.DesktopMenuRefreshInterval(), procsWg: &sync.WaitGroup{}, - interruptTimeout: time.Second * 10, + interruptTimeout: time.Second * 5, hostname: k.KolideServerURL(), usersFilesRoot: agent.TempPath("kolide-desktop"), processSpawningEnabled: k.DesktopEnabled(), @@ -259,22 +260,34 @@ func (r *DesktopUsersProcessesRunner) Interrupt(_ error) { // Tell the execute loop to stop checking, and exit r.interrupt <- struct{}{} - // Kill any desktop processes that may exist - r.killDesktopProcesses() + // The timeout for `Interrupt` is the desktop process interrupt timeout (r.interruptTimeout) + // plus a small buffer for killing processes that couldn't be shut down gracefully during r.interruptTimeout. + shutdownTimeout := r.interruptTimeout + 3*time.Second + // This timeout for `Interrupt` should not be larger than rungroup.InterruptTimeout. 
+ if shutdownTimeout > rungroup.InterruptTimeout { + shutdownTimeout = rungroup.InterruptTimeout + } - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) defer cancel() + // Kill any desktop processes that may exist + r.killDesktopProcesses(ctx) + if err := r.runnerServer.Shutdown(ctx); err != nil { r.slogger.Log(ctx, slog.LevelError, "shutting down monitor server", "err", err, ) } + + r.slogger.Log(ctx, slog.LevelInfo, + "desktop runner shutdown complete", + ) } // killDesktopProcesses kills any existing desktop processes -func (r *DesktopUsersProcessesRunner) killDesktopProcesses() { +func (r *DesktopUsersProcessesRunner) killDesktopProcesses(ctx context.Context) { wgDone := make(chan struct{}) go func() { defer close(wgDone) @@ -287,8 +300,8 @@ func (r *DesktopUsersProcessesRunner) killDesktopProcesses() { r.runnerServer.DeRegisterClient(uid) client := client.New(r.userServerAuthToken, proc.socketPath) - if err := client.Shutdown(); err != nil { - r.slogger.Log(context.TODO(), slog.LevelError, + if err := client.Shutdown(ctx); err != nil { + r.slogger.Log(ctx, slog.LevelError, "sending shutdown command to user desktop process", "uid", uid, "pid", proc.Process.Pid, @@ -303,7 +316,7 @@ func (r *DesktopUsersProcessesRunner) killDesktopProcesses() { select { case <-wgDone: if shutdownRequestCount > 0 { - r.slogger.Log(context.TODO(), slog.LevelDebug, + r.slogger.Log(ctx, slog.LevelDebug, "successfully completed desktop process shutdown requests", "count", shutdownRequestCount, ) @@ -312,7 +325,7 @@ func (r *DesktopUsersProcessesRunner) killDesktopProcesses() { maps.Clear(r.uidProcs) return case <-time.After(r.interruptTimeout): - r.slogger.Log(context.TODO(), slog.LevelError, + r.slogger.Log(ctx, slog.LevelError, "timeout waiting for desktop processes to exit, now killing", ) @@ -321,7 +334,7 @@ func (r *DesktopUsersProcessesRunner) killDesktopProcesses() { continue } if 
err := processRecord.Process.Kill(); err != nil { - r.slogger.Log(context.TODO(), slog.LevelError, + r.slogger.Log(ctx, slog.LevelError, "killing desktop process", "uid", uid, "pid", processRecord.Process.Pid, @@ -331,6 +344,10 @@ func (r *DesktopUsersProcessesRunner) killDesktopProcesses() { } } } + + r.slogger.Log(ctx, slog.LevelInfo, + "killed user desktop processes", + ) } func (r *DesktopUsersProcessesRunner) SendNotification(n notify.Notification) error { @@ -516,7 +533,7 @@ func (r *DesktopUsersProcessesRunner) writeDefaultMenuTemplateFile() { func (r *DesktopUsersProcessesRunner) runConsoleUserDesktop() error { if !r.processSpawningEnabled { // Desktop is disabled, kill any existing desktop user processes - r.killDesktopProcesses() + r.killDesktopProcesses(context.Background()) return nil } @@ -909,6 +926,6 @@ func (r *DesktopUsersProcessesRunner) checkOsUpdate() { "new", currentOsVersion, ) r.osVersion = currentOsVersion - r.killDesktopProcesses() + r.killDesktopProcesses(context.Background()) } } diff --git a/ee/desktop/user/client/client.go b/ee/desktop/user/client/client.go index 446924413..8fe1641b6 100644 --- a/ee/desktop/user/client/client.go +++ b/ee/desktop/user/client/client.go @@ -2,6 +2,7 @@ package client import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -42,8 +43,8 @@ func New(authToken, socketPath string) client { return client } -func (c *client) Shutdown() error { - return c.get("shutdown") +func (c *client) Shutdown(ctx context.Context) error { + return c.getWithContext(ctx, "shutdown") } func (c *client) Ping() error { @@ -95,3 +96,25 @@ func (c *client) get(path string) error { return nil } + +func (c *client) getWithContext(ctx context.Context, path string) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://unix/%s", path), nil) + if err != nil { + return fmt.Errorf("creating request with context: %w", err) + } + + resp, err := c.base.Do(req) + if err != nil { + return 
fmt.Errorf("making request: %w", err) + } + + if resp.Body != nil { + resp.Body.Close() + } + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + return nil +} diff --git a/pkg/rungroup/rungroup.go b/pkg/rungroup/rungroup.go index b674d4ac9..05ee46d32 100644 --- a/pkg/rungroup/rungroup.go +++ b/pkg/rungroup/rungroup.go @@ -33,8 +33,8 @@ type ( ) const ( - interruptTimeout = 5 * time.Second // How long for all actors to return from their `interrupt` function - executeReturnTimeout = 5 * time.Second // After interrupted, how long for all actors to exit their `execute` functions + InterruptTimeout = 10 * time.Second // How long for all actors to return from their `interrupt` function + executeReturnTimeout = 5 * time.Second // After interrupted, how long for all actors to exit their `execute` functions ) func NewRunGroup(slogger *slog.Logger) *Group { @@ -107,7 +107,7 @@ func (g *Group) Run() error { }(a) } - interruptCtx, interruptCancel := context.WithTimeout(context.Background(), interruptTimeout) + interruptCtx, interruptCancel := context.WithTimeout(context.Background(), InterruptTimeout) defer interruptCancel() // Wait for interrupts to complete, but only until we hit our interruptCtx timeout diff --git a/pkg/rungroup/rungroup_test.go b/pkg/rungroup/rungroup_test.go index 9796c007b..e011f9f35 100644 --- a/pkg/rungroup/rungroup_test.go +++ b/pkg/rungroup/rungroup_test.go @@ -60,7 +60,7 @@ func TestRun_MultipleActors(t *testing.T) { }() // 1 second before interrupt, waiting for interrupt, and waiting for execute return, plus a little buffer - runDuration := 1*time.Second + interruptTimeout + executeReturnTimeout + 1*time.Second + runDuration := 1*time.Second + InterruptTimeout + executeReturnTimeout + 1*time.Second interruptCheckTimer := time.NewTicker(runDuration) defer interruptCheckTimer.Stop() @@ -119,7 +119,7 @@ func TestRun_MultipleActors_InterruptTimeout(t *testing.T) { <-blockingActorInterrupt 
return nil }, func(error) { - time.Sleep(4 * interruptTimeout) + time.Sleep(4 * InterruptTimeout) groupReceivedInterrupts <- struct{}{} blockingActorInterrupt <- struct{}{} }) @@ -132,7 +132,7 @@ func TestRun_MultipleActors_InterruptTimeout(t *testing.T) { }() // 1 second before interrupt, waiting for interrupt, and waiting for execute return, plus a little buffer - runDuration := 1*time.Second + interruptTimeout + executeReturnTimeout + 1*time.Second + runDuration := 1*time.Second + InterruptTimeout + executeReturnTimeout + 1*time.Second interruptCheckTimer := time.NewTicker(runDuration) defer interruptCheckTimer.Stop() @@ -208,7 +208,7 @@ func TestRun_MultipleActors_ExecuteReturnTimeout(t *testing.T) { }() // 1 second before interrupt, waiting for interrupt, and waiting for execute return, plus a little buffer - runDuration := 1*time.Second + interruptTimeout + executeReturnTimeout + 1*time.Second + runDuration := 1*time.Second + InterruptTimeout + executeReturnTimeout + 1*time.Second interruptCheckTimer := time.NewTicker(runDuration) defer interruptCheckTimer.Stop()