Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure desktop runner shuts down within rungroup interrupt timeout, and log shutdown completion #1668

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 29 additions & 12 deletions ee/desktop/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/kolide/launcher/ee/desktop/user/notify"
"github.com/kolide/launcher/ee/ui/assets"
"github.com/kolide/launcher/pkg/backoff"
"github.com/kolide/launcher/pkg/rungroup"
"github.com/kolide/launcher/pkg/traces"
"github.com/shirou/gopsutil/v3/process"
"golang.org/x/exp/maps"
Expand Down Expand Up @@ -159,7 +160,7 @@ func New(k types.Knapsack, messenger runnerserver.Messenger, opts ...desktopUser
updateInterval: k.DesktopUpdateInterval(),
menuRefreshInterval: k.DesktopMenuRefreshInterval(),
procsWg: &sync.WaitGroup{},
interruptTimeout: time.Second * 10,
interruptTimeout: time.Second * 5,
hostname: k.KolideServerURL(),
usersFilesRoot: agent.TempPath("kolide-desktop"),
processSpawningEnabled: k.DesktopEnabled(),
Expand Down Expand Up @@ -259,22 +260,34 @@ func (r *DesktopUsersProcessesRunner) Interrupt(_ error) {
// Tell the execute loop to stop checking, and exit
r.interrupt <- struct{}{}

// Kill any desktop processes that may exist
r.killDesktopProcesses()
// The timeout for `Interrupt` is the desktop process interrupt timeout (r.interruptTimeout)
// plus a small buffer for killing processes that couldn't be shut down gracefully during r.interuptTimeout.
shutdownTimeout := r.interruptTimeout + 3*time.Second
// This timeout for `Interrupt` should not be larger than rungroup.interruptTimeout.
if shutdownTimeout > rungroup.InterruptTimeout {
shutdownTimeout = rungroup.InterruptTimeout
}

ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
defer cancel()

// Kill any desktop processes that may exist
r.killDesktopProcesses(ctx)

if err := r.runnerServer.Shutdown(ctx); err != nil {
r.slogger.Log(ctx, slog.LevelError,
"shutting down monitor server",
"err", err,
)
}

r.slogger.Log(ctx, slog.LevelInfo,
"desktop runner shutdown complete",
)
}

// killDesktopProcesses kills any existing desktop processes
func (r *DesktopUsersProcessesRunner) killDesktopProcesses() {
func (r *DesktopUsersProcessesRunner) killDesktopProcesses(ctx context.Context) {
wgDone := make(chan struct{})
go func() {
defer close(wgDone)
Expand All @@ -287,8 +300,8 @@ func (r *DesktopUsersProcessesRunner) killDesktopProcesses() {
r.runnerServer.DeRegisterClient(uid)

client := client.New(r.userServerAuthToken, proc.socketPath)
if err := client.Shutdown(); err != nil {
r.slogger.Log(context.TODO(), slog.LevelError,
if err := client.Shutdown(ctx); err != nil {
r.slogger.Log(ctx, slog.LevelError,
"sending shutdown command to user desktop process",
"uid", uid,
"pid", proc.Process.Pid,
Expand All @@ -303,7 +316,7 @@ func (r *DesktopUsersProcessesRunner) killDesktopProcesses() {
select {
case <-wgDone:
if shutdownRequestCount > 0 {
r.slogger.Log(context.TODO(), slog.LevelDebug,
r.slogger.Log(ctx, slog.LevelDebug,
"successfully completed desktop process shutdown requests",
"count", shutdownRequestCount,
)
Expand All @@ -312,7 +325,7 @@ func (r *DesktopUsersProcessesRunner) killDesktopProcesses() {
maps.Clear(r.uidProcs)
return
case <-time.After(r.interruptTimeout):
r.slogger.Log(context.TODO(), slog.LevelError,
r.slogger.Log(ctx, slog.LevelError,
"timeout waiting for desktop processes to exit, now killing",
)

Expand All @@ -321,7 +334,7 @@ func (r *DesktopUsersProcessesRunner) killDesktopProcesses() {
continue
}
if err := processRecord.Process.Kill(); err != nil {
r.slogger.Log(context.TODO(), slog.LevelError,
r.slogger.Log(ctx, slog.LevelError,
"killing desktop process",
"uid", uid,
"pid", processRecord.Process.Pid,
Expand All @@ -331,6 +344,10 @@ func (r *DesktopUsersProcessesRunner) killDesktopProcesses() {
}
}
}

r.slogger.Log(ctx, slog.LevelInfo,
"killed user desktop processes",
)
}

func (r *DesktopUsersProcessesRunner) SendNotification(n notify.Notification) error {
Expand Down Expand Up @@ -516,7 +533,7 @@ func (r *DesktopUsersProcessesRunner) writeDefaultMenuTemplateFile() {
func (r *DesktopUsersProcessesRunner) runConsoleUserDesktop() error {
if !r.processSpawningEnabled {
// Desktop is disabled, kill any existing desktop user processes
r.killDesktopProcesses()
r.killDesktopProcesses(context.Background())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this intentionally starting a new contenxt? (or should it be context.TODO())

return nil
}

Expand Down Expand Up @@ -909,6 +926,6 @@ func (r *DesktopUsersProcessesRunner) checkOsUpdate() {
"new", currentOsVersion,
)
r.osVersion = currentOsVersion
r.killDesktopProcesses()
r.killDesktopProcesses(context.Background())
}
}
27 changes: 25 additions & 2 deletions ee/desktop/user/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package client

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
Expand Down Expand Up @@ -42,8 +43,8 @@ func New(authToken, socketPath string) client {
return client
}

func (c *client) Shutdown() error {
return c.get("shutdown")
func (c *client) Shutdown(ctx context.Context) error {
return c.getWithContext(ctx, "shutdown")
}

func (c *client) Ping() error {
Expand Down Expand Up @@ -95,3 +96,25 @@ func (c *client) get(path string) error {

return nil
}

func (c *client) getWithContext(ctx context.Context, path string) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://unix/%s", path), nil)
if err != nil {
return fmt.Errorf("creating request with context: %w", err)
}

resp, err := c.base.Do(req)
if err != nil {
return fmt.Errorf("making request: %w", err)
}

if resp.Body != nil {
resp.Body.Close()
}

if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}

return nil
}
6 changes: 3 additions & 3 deletions pkg/rungroup/rungroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ type (
)

const (
interruptTimeout = 5 * time.Second // How long for all actors to return from their `interrupt` function
executeReturnTimeout = 5 * time.Second // After interrupted, how long for all actors to exit their `execute` functions
InterruptTimeout = 10 * time.Second // How long for all actors to return from their `interrupt` function
executeReturnTimeout = 5 * time.Second // After interrupted, how long for all actors to exit their `execute` functions
)

func NewRunGroup(slogger *slog.Logger) *Group {
Expand Down Expand Up @@ -107,7 +107,7 @@ func (g *Group) Run() error {
}(a)
}

interruptCtx, interruptCancel := context.WithTimeout(context.Background(), interruptTimeout)
interruptCtx, interruptCancel := context.WithTimeout(context.Background(), InterruptTimeout)
defer interruptCancel()

// Wait for interrupts to complete, but only until we hit our interruptCtx timeout
Expand Down
8 changes: 4 additions & 4 deletions pkg/rungroup/rungroup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func TestRun_MultipleActors(t *testing.T) {
}()

// 1 second before interrupt, waiting for interrupt, and waiting for execute return, plus a little buffer
runDuration := 1*time.Second + interruptTimeout + executeReturnTimeout + 1*time.Second
runDuration := 1*time.Second + InterruptTimeout + executeReturnTimeout + 1*time.Second
interruptCheckTimer := time.NewTicker(runDuration)
defer interruptCheckTimer.Stop()

Expand Down Expand Up @@ -119,7 +119,7 @@ func TestRun_MultipleActors_InterruptTimeout(t *testing.T) {
<-blockingActorInterrupt
return nil
}, func(error) {
time.Sleep(4 * interruptTimeout)
time.Sleep(4 * InterruptTimeout)
groupReceivedInterrupts <- struct{}{}
blockingActorInterrupt <- struct{}{}
})
Expand All @@ -132,7 +132,7 @@ func TestRun_MultipleActors_InterruptTimeout(t *testing.T) {
}()

// 1 second before interrupt, waiting for interrupt, and waiting for execute return, plus a little buffer
runDuration := 1*time.Second + interruptTimeout + executeReturnTimeout + 1*time.Second
runDuration := 1*time.Second + InterruptTimeout + executeReturnTimeout + 1*time.Second
interruptCheckTimer := time.NewTicker(runDuration)
defer interruptCheckTimer.Stop()

Expand Down Expand Up @@ -208,7 +208,7 @@ func TestRun_MultipleActors_ExecuteReturnTimeout(t *testing.T) {
}()

// 1 second before interrupt, waiting for interrupt, and waiting for execute return, plus a little buffer
runDuration := 1*time.Second + interruptTimeout + executeReturnTimeout + 1*time.Second
runDuration := 1*time.Second + InterruptTimeout + executeReturnTimeout + 1*time.Second
interruptCheckTimer := time.NewTicker(runDuration)
defer interruptCheckTimer.Stop()

Expand Down
Loading