Skip to content

Commit

Permalink
Add remote restart consumer to handle remote restart actions (#1948)
Browse files Browse the repository at this point in the history
  • Loading branch information
RebeccaMahany authored Nov 12, 2024
1 parent bc44bcb commit c6fe8b7
Show file tree
Hide file tree
Showing 4 changed files with 298 additions and 2 deletions.
5 changes: 5 additions & 0 deletions cmd/launcher/launcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
"github.com/kolide/launcher/ee/control/consumers/flareconsumer"
"github.com/kolide/launcher/ee/control/consumers/keyvalueconsumer"
"github.com/kolide/launcher/ee/control/consumers/notificationconsumer"
"github.com/kolide/launcher/ee/control/consumers/remoterestartconsumer"
"github.com/kolide/launcher/ee/control/consumers/uninstallconsumer"
"github.com/kolide/launcher/ee/debug/checkups"
desktopRunner "github.com/kolide/launcher/ee/desktop/runner"
Expand Down Expand Up @@ -469,6 +470,10 @@ func runLauncher(ctx context.Context, cancel func(), multiSlogger, systemMultiSl
// register notifications consumer
actionsQueue.RegisterActor(notificationconsumer.NotificationSubsystem, notificationConsumer)

remoteRestartConsumer := remoterestartconsumer.New(k)
runGroup.Add("remoteRestart", remoteRestartConsumer.Execute, remoteRestartConsumer.Interrupt)
actionsQueue.RegisterActor(remoterestartconsumer.RemoteRestartActorType, remoteRestartConsumer)

// Set up our tracing instrumentation
authTokenConsumer := keyvalueconsumer.New(k.TokenStore())
if err := controlService.RegisterConsumer(authTokensSubsystemName, authTokenConsumer); err != nil {
Expand Down
5 changes: 3 additions & 2 deletions cmd/launcher/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/kolide/kit/env"
"github.com/kolide/kit/logutil"
"github.com/kolide/kit/version"
"github.com/kolide/launcher/ee/control/consumers/remoterestartconsumer"
"github.com/kolide/launcher/ee/tuf"
"github.com/kolide/launcher/ee/watchdog"
"github.com/kolide/launcher/pkg/contexts/ctxlog"
Expand Down Expand Up @@ -153,11 +154,11 @@ func runMain() int {
ctx = ctxlog.NewContext(ctx, logger)

if err := runLauncher(ctx, cancel, slogger, systemSlogger, opts); err != nil {
if !tuf.IsLauncherReloadNeededErr(err) {
if !tuf.IsLauncherReloadNeededErr(err) && !errors.Is(err, remoterestartconsumer.ErrRemoteRestartRequested) {
level.Debug(logger).Log("msg", "run launcher", "stack", fmt.Sprintf("%+v", err))
return 1
}
level.Debug(logger).Log("msg", "runLauncher exited to run newer version of launcher", "err", err.Error())
level.Debug(logger).Log("msg", "runLauncher exited to reload launcher", "err", err.Error())
if err := runNewerLauncherIfAvailable(ctx, slogger.Logger); err != nil {
return 1
}
Expand Down
130 changes: 130 additions & 0 deletions ee/control/consumers/remoterestartconsumer/remoterestartconsumer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package remoterestartconsumer

import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"time"

"github.com/kolide/launcher/ee/agent/types"
)

const (
// RemoteRestartActorType identifies this action/actor type, which performs
// a launcher restart when requested by the control server. This actor type
// belongs to the action subsystem.
RemoteRestartActorType = "remote_restart"

// restartDelay is the delay after receiving action before triggering the restart.
// We have a delay to allow the actionqueue.
restartDelay = 15 * time.Second
)

var (
ErrRemoteRestartRequested = errors.New("need to reload launcher: remote restart requested")
)

type RemoteRestartConsumer struct {
knapsack types.Knapsack
slogger *slog.Logger
signalRestart chan error
interrupt chan struct{}
interrupted bool
}

type remoteRestartAction struct {
RunID string `json:"run_id"` // the run ID for the launcher run to restart
}

func New(knapsack types.Knapsack) *RemoteRestartConsumer {
return &RemoteRestartConsumer{
knapsack: knapsack,
slogger: knapsack.Slogger().With("component", "remote_restart_consumer"),
signalRestart: make(chan error, 1),
interrupt: make(chan struct{}, 1),
}
}

// Do implements the `actionqueue.actor` interface, and allows the actionqueue
// to pass `remote_restart` type actions to this consumer. The actionqueue validates
// that this action has not already been performed and that this action is still
// valid (i.e. not expired). `Do` additionally validates that the `run_id` given in
// the action matches the current launcher run ID.
func (r *RemoteRestartConsumer) Do(data io.Reader) error {
var restartAction remoteRestartAction

if err := json.NewDecoder(data).Decode(&restartAction); err != nil {
return fmt.Errorf("decoding restart action: %w", err)
}

// The action's run ID indicates the current `runLauncher` that should be restarted.
// If the action's run ID does not match the current run ID, we assume the restart
// has already happened and does not need to happen again.
if restartAction.RunID == "" {
r.slogger.Log(context.TODO(), slog.LevelInfo,
"received remote restart action with blank launcher run ID -- discarding",
)
return nil
}
if restartAction.RunID != r.knapsack.GetRunID() {
r.slogger.Log(context.TODO(), slog.LevelInfo,
"received remote restart action for incorrect (assuming past) launcher run ID -- discarding",
"action_run_id", restartAction.RunID,
)
return nil
}

// Perform the restart by signaling actor shutdown, but delay slightly to give
// the actionqueue a chance to process all actions and store their statuses.
go func() {
r.slogger.Log(context.TODO(), slog.LevelInfo,
"received remote restart action for current launcher run ID -- signaling for restart shortly",
"action_run_id", restartAction.RunID,
"restart_delay", restartDelay.String(),
)

select {
case <-r.interrupt:
r.slogger.Log(context.TODO(), slog.LevelDebug,
"received external interrupt before remote restart could be performed",
)
return
case <-time.After(restartDelay):
r.signalRestart <- ErrRemoteRestartRequested
r.slogger.Log(context.TODO(), slog.LevelInfo,
"signaled for restart after delay",
"action_run_id", restartAction.RunID,
)
return
}
}()

return nil
}

// Execute allows the remote restart consumer to run in the main launcher rungroup.
// It waits until it receives a remote restart action from `Do`, or until it receives
// a `Interrupt` request.
func (r *RemoteRestartConsumer) Execute() (err error) {
select {
case <-r.interrupt:
return nil
case signalRestartErr := <-r.signalRestart:
return signalRestartErr
}
}

// Interrupt allows the remote restart consumer to run in the main launcher rungroup
// and be shut down when the rungroup shuts down.
func (r *RemoteRestartConsumer) Interrupt(_ error) {
// Only perform shutdown tasks on first call to interrupt -- no need to repeat on potential extra calls.
if r.interrupted {
return
}
r.interrupted = true

r.interrupt <- struct{}{}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package remoterestartconsumer

import (
"bytes"
"encoding/json"
"errors"
"testing"
"time"

"github.com/kolide/kit/ulid"
typesmocks "github.com/kolide/launcher/ee/agent/types/mocks"
"github.com/kolide/launcher/pkg/log/multislogger"
"github.com/stretchr/testify/require"
)

func TestDo(t *testing.T) {
t.Parallel()

currentRunId := ulid.New()

mockKnapsack := typesmocks.NewKnapsack(t)
mockKnapsack.On("Slogger").Return(multislogger.NewNopLogger())
mockKnapsack.On("GetRunID").Return(currentRunId)

remoteRestarter := New(mockKnapsack)

testAction := remoteRestartAction{
RunID: currentRunId,
}
testActionRaw, err := json.Marshal(testAction)
require.NoError(t, err)

// We don't expect an error because we should process the action
require.NoError(t, remoteRestarter.Do(bytes.NewReader(testActionRaw)), "expected no error processing valid remote restart action")

// The restarter should delay before sending an error to `signalRestart`
require.Len(t, remoteRestarter.signalRestart, 0, "expected restarter to delay before signal for restart but channel is already has item in it")
time.Sleep(restartDelay + 2*time.Second)
require.Len(t, remoteRestarter.signalRestart, 1, "expected restarter to signal for restart but channel is empty after delay")
}

func TestDo_DoesNotSignalRestartWhenRunIDDoesNotMatch(t *testing.T) {
t.Parallel()

currentRunId := ulid.New()

mockKnapsack := typesmocks.NewKnapsack(t)
mockKnapsack.On("Slogger").Return(multislogger.NewNopLogger())
mockKnapsack.On("GetRunID").Return(currentRunId)

remoteRestarter := New(mockKnapsack)

testAction := remoteRestartAction{
RunID: ulid.New(), // run ID will not match `currentRunId`
}
testActionRaw, err := json.Marshal(testAction)
require.NoError(t, err)

// We don't expect an error because we want to discard this action
require.NoError(t, remoteRestarter.Do(bytes.NewReader(testActionRaw)), "should not return error for old run ID")

// The restarter should not send an error to `signalRestart`
time.Sleep(restartDelay + 2*time.Second)
require.Len(t, remoteRestarter.signalRestart, 0, "restarter should not have signaled for a restart, but channel is not empty")
}

func TestDo_DoesNotSignalRestartWhenRunIDIsEmpty(t *testing.T) {
t.Parallel()

mockKnapsack := typesmocks.NewKnapsack(t)
mockKnapsack.On("Slogger").Return(multislogger.NewNopLogger())

remoteRestarter := New(mockKnapsack)

testAction := remoteRestartAction{
RunID: "", // run ID is empty
}
testActionRaw, err := json.Marshal(testAction)
require.NoError(t, err)

// We don't expect an error because we want to discard this action
require.NoError(t, remoteRestarter.Do(bytes.NewReader(testActionRaw)), "should not return error for empty run ID")

// The restarter should not send an error to `signalRestart`
time.Sleep(restartDelay + 2*time.Second)
require.Len(t, remoteRestarter.signalRestart, 0, "restarter should not have signaled for a restart, but channel is not empty")
}

func TestDo_DoesNotRestartIfInterruptedDuringDelay(t *testing.T) {
t.Parallel()

currentRunId := ulid.New()

mockKnapsack := typesmocks.NewKnapsack(t)
mockKnapsack.On("Slogger").Return(multislogger.NewNopLogger())
mockKnapsack.On("GetRunID").Return(currentRunId)

remoteRestarter := New(mockKnapsack)

testAction := remoteRestartAction{
RunID: currentRunId,
}
testActionRaw, err := json.Marshal(testAction)
require.NoError(t, err)

// We don't expect an error because the run ID is correct
require.NoError(t, remoteRestarter.Do(bytes.NewReader(testActionRaw)), "expected no error processing valid remote restart action")

// The restarter should delay before sending an error to `signalRestart`
require.Len(t, remoteRestarter.signalRestart, 0, "expected restarter to delay before signal for restart but channel is already has item in it")

// Now, send an interrupt
remoteRestarter.Interrupt(errors.New("test error"))

// Sleep beyond the interrupt delay, and confirm we don't try to do a restart when we're already shutting down
time.Sleep(restartDelay + 2*time.Second)
require.Len(t, remoteRestarter.signalRestart, 0, "restarter should not have tried to signal for restart when interrupted during restart delay")
}

func TestInterrupt_Multiple(t *testing.T) {
t.Parallel()

mockKnapsack := typesmocks.NewKnapsack(t)
mockKnapsack.On("Slogger").Return(multislogger.NewNopLogger())

remoteRestarter := New(mockKnapsack)

// Let the remote restarter run for a bit
go remoteRestarter.Execute()
time.Sleep(3 * time.Second)
remoteRestarter.Interrupt(errors.New("test error"))

// Confirm we can call Interrupt multiple times without blocking
interruptComplete := make(chan struct{})
expectedInterrupts := 3
for i := 0; i < expectedInterrupts; i += 1 {
go func() {
remoteRestarter.Interrupt(nil)
interruptComplete <- struct{}{}
}()
}

receivedInterrupts := 0
for {
if receivedInterrupts >= expectedInterrupts {
break
}

select {
case <-interruptComplete:
receivedInterrupts += 1
continue
case <-time.After(5 * time.Second):
t.Errorf("could not call interrupt multiple times and return within 5 seconds -- received %d interrupts before timeout", receivedInterrupts)
t.FailNow()
}
}

require.Equal(t, expectedInterrupts, receivedInterrupts)
}

0 comments on commit c6fe8b7

Please sign in to comment.