Skip to content

Commit

Permalink
Support running multiple osquery instances (#1941)
Browse files Browse the repository at this point in the history
  • Loading branch information
RebeccaMahany authored Nov 6, 2024
1 parent a50e9fd commit 9ead31a
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 67 deletions.
2 changes: 1 addition & 1 deletion pkg/osquery/runtime/osqueryinstance.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ func newInstance(registrationId string, knapsack types.Knapsack, serviceClient s
i := &OsqueryInstance{
registrationId: registrationId,
knapsack: knapsack,
slogger: knapsack.Slogger().With("component", "osquery_instance", "instance_run_id", runId),
slogger: knapsack.Slogger().With("component", "osquery_instance", "registration_id", registrationId, "instance_run_id", runId),
serviceClient: serviceClient,
runId: runId,
}
Expand Down
168 changes: 123 additions & 45 deletions pkg/osquery/runtime/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,23 @@ package runtime

import (
"context"
"errors"
"fmt"
"log/slog"
"sync"

"github.com/kolide/launcher/ee/agent/flags/keys"
"github.com/kolide/launcher/ee/agent/types"
"github.com/kolide/launcher/pkg/service"
"golang.org/x/sync/errgroup"
)

const (
defaultRegistrationId = "default"
)

type Runner struct {
instance *OsqueryInstance
instances map[string]*OsqueryInstance // maps registration ID to instance
instanceLock sync.Mutex
slogger *slog.Logger
knapsack types.Knapsack
Expand All @@ -28,7 +30,10 @@ type Runner struct {

func New(k types.Knapsack, serviceClient service.KolideService, opts ...OsqueryInstanceOption) *Runner {
runner := &Runner{
instance: newInstance(defaultRegistrationId, k, serviceClient, opts...),
instances: map[string]*OsqueryInstance{
// For now, we only have one (default) instance and we use it for all queries
defaultRegistrationId: newInstance(defaultRegistrationId, k, serviceClient, opts...),
},
slogger: k.Slogger().With("component", "osquery_runner"),
knapsack: k,
serviceClient: serviceClient,
Expand All @@ -44,60 +49,89 @@ func New(k types.Knapsack, serviceClient service.KolideService, opts ...OsqueryI
}

func (r *Runner) Run() error {
// Ensure we don't try to restart the instance before it's launched
// Create a group to track the workers running each instance
wg, ctx := errgroup.WithContext(context.Background())

// Start each worker for each instance
for registrationId := range r.instances {
id := registrationId
wg.Go(func() error {
if err := r.runInstance(id); err != nil {
r.slogger.Log(ctx, slog.LevelWarn,
"runner terminated running osquery instance unexpectedly, shutting down runner",
"err", err,
)

if err := r.Shutdown(); err != nil {
r.slogger.Log(ctx, slog.LevelError,
"could not shut down runner after failure to run osquery instance",
"err", err,
)
}
return err
}

return nil
})
}

// Wait for all workers to exit
if err := wg.Wait(); err != nil {
return fmt.Errorf("running osquery instances: %w", err)
}

return nil
}

// runInstance starts a worker that launches the instance for the given registration ID, and
// then ensures that instance stays up. It exits if `Shutdown` is called, or if the instance
// exits and cannot be restarted.
func (r *Runner) runInstance(registrationId string) error {
slogger := r.slogger.With("registration_id", registrationId)

// First, launch the instance. Ensure we don't try to restart before launch is complete.
r.instanceLock.Lock()
if err := r.instance.Launch(); err != nil {
r.slogger.Log(context.TODO(), slog.LevelWarn,
"failed to launch osquery instance",
"err", err,
)
instance, ok := r.instances[registrationId]
if !ok {
r.instanceLock.Unlock()
return fmt.Errorf("no instance exists for %s", registrationId)
}
if err := instance.Launch(); err != nil {
r.instanceLock.Unlock()
return fmt.Errorf("starting instance: %w", err)
return fmt.Errorf("starting instance for %s: %w", registrationId, err)
}
r.instanceLock.Unlock()

// This loop waits for the completion of the async routines,
// and either restarts the instance (if Shutdown was not
// called), or stops (if Shutdown was called).
// This loop restarts the instance as necessary. It exits when `Shutdown` is called,
// or if the instance exits and cannot be restarted.
for {
// Wait for async processes to exit
<-r.instance.Exited()
r.slogger.Log(context.TODO(), slog.LevelInfo,
<-instance.Exited()
slogger.Log(context.TODO(), slog.LevelInfo,
"osquery instance exited",
)

select {
case <-r.shutdown:
// Intentional shutdown, this loop can exit
// Intentional shutdown of runner -- exit worker
return nil
default:
// Don't block
// Continue on to restart the instance
}

// Error case -- osquery instance shut down and needs to be restarted
err := r.instance.WaitShutdown()
r.slogger.Log(context.TODO(), slog.LevelInfo,
// The osquery instance either exited on its own, or we called `Restart`.
// Either way, we wait for exit to complete, and then restart the instance.
err := instance.WaitShutdown()
slogger.Log(context.TODO(), slog.LevelInfo,
"unexpected restart of instance",
"err", err,
)

r.instanceLock.Lock()
r.instance = newInstance(defaultRegistrationId, r.knapsack, r.serviceClient, r.opts...)
if err := r.instance.Launch(); err != nil {
r.slogger.Log(context.TODO(), slog.LevelWarn,
"fatal error restarting instance, shutting down",
"err", err,
)
instance = newInstance(registrationId, r.knapsack, r.serviceClient, r.opts...)
r.instances[registrationId] = instance
if err := instance.Launch(); err != nil {
r.instanceLock.Unlock()
if err := r.Shutdown(); err != nil {
r.slogger.Log(context.TODO(), slog.LevelWarn,
"could not perform shutdown",
"err", err,
)
}

// Failed to restart instance -- exit rungroup so launcher can reload
return fmt.Errorf("restarting instance after unexpected exit: %w", err)
return fmt.Errorf("could not restart osquery instance after unexpected exit: %w", err)
}

r.instanceLock.Unlock()
Expand All @@ -107,7 +141,14 @@ func (r *Runner) Run() error {
func (r *Runner) Query(query string) ([]map[string]string, error) {
r.instanceLock.Lock()
defer r.instanceLock.Unlock()
return r.instance.Query(query)

// For now, grab the default (i.e. only) instance
instance, ok := r.instances[defaultRegistrationId]
if !ok {
return nil, errors.New("no default instance exists, cannot query")
}

return instance.Query(query)
}

func (r *Runner) Interrupt(_ error) {
Expand All @@ -129,12 +170,37 @@ func (r *Runner) Shutdown() error {

r.interrupted = true
close(r.shutdown)

if err := r.triggerShutdownForInstances(); err != nil {
return fmt.Errorf("triggering shutdown for instances during runner shutdown: %w", err)
}

return nil
}

// triggerShutdownForInstances asks all instances in `r.instances` to shut down.
func (r *Runner) triggerShutdownForInstances() error {
r.instanceLock.Lock()
defer r.instanceLock.Unlock()
r.instance.BeginShutdown()
if err := r.instance.WaitShutdown(); err != context.Canceled && err != nil {
return fmt.Errorf("while shutting down instance: %w", err)

// Shut down the instances in parallel
shutdownWg, _ := errgroup.WithContext(context.Background())
for registrationId, instance := range r.instances {
id := registrationId
i := instance
shutdownWg.Go(func() error {
i.BeginShutdown()
if err := i.WaitShutdown(); err != context.Canceled && err != nil {
return fmt.Errorf("shutting down instance %s: %w", id, err)
}
return nil
})
}

if err := shutdownWg.Wait(); err != nil {
return fmt.Errorf("shutting down all instances: %+v", err)
}

return nil
}

Expand Down Expand Up @@ -176,11 +242,11 @@ func (r *Runner) Restart() error {
r.slogger.Log(context.TODO(), slog.LevelDebug,
"runner.Restart called",
)
r.instanceLock.Lock()
defer r.instanceLock.Unlock()
// Shut down the instance -- `Run` will start a new one.
r.instance.BeginShutdown()
r.instance.WaitShutdown()

// Shut down the instances -- this will trigger a restart in each `runInstance`.
if err := r.triggerShutdownForInstances(); err != nil {
return fmt.Errorf("triggering shutdown for instances during runner restart: %w", err)
}

return nil
}
Expand All @@ -190,5 +256,17 @@ func (r *Runner) Restart() error {
func (r *Runner) Healthy() error {
r.instanceLock.Lock()
defer r.instanceLock.Unlock()
return r.instance.Healthy()

healthcheckErrs := make([]error, 0)
for registrationId, instance := range r.instances {
if err := instance.Healthy(); err != nil {
healthcheckErrs = append(healthcheckErrs, fmt.Errorf("healthcheck error for %s: %w", registrationId, err))
}
}

if len(healthcheckErrs) > 0 {
return fmt.Errorf("healthchecking all instances: %+v", healthcheckErrs)
}

return nil
}
12 changes: 6 additions & 6 deletions pkg/osquery/runtime/runtime_posix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,24 +134,24 @@ func TestRestart(t *testing.T) {
runner, logBytes, teardown := setupOsqueryInstanceForTests(t)
defer teardown()

previousStats := runner.instance.stats
previousStats := runner.instances[defaultRegistrationId].stats

require.NoError(t, runner.Restart())
waitHealthy(t, runner, logBytes)

require.NotEmpty(t, runner.instance.stats.StartTime, "start time should be set on latest instance stats after restart")
require.NotEmpty(t, runner.instance.stats.ConnectTime, "connect time should be set on latest instance stats after restart")
require.NotEmpty(t, runner.instances[defaultRegistrationId].stats.StartTime, "start time should be set on latest instance stats after restart")
require.NotEmpty(t, runner.instances[defaultRegistrationId].stats.ConnectTime, "connect time should be set on latest instance stats after restart")

require.NotEmpty(t, previousStats.ExitTime, "exit time should be set on last instance stats when restarted")
require.NotEmpty(t, previousStats.Error, "stats instance should have an error on restart")

previousStats = runner.instance.stats
previousStats = runner.instances[defaultRegistrationId].stats

require.NoError(t, runner.Restart())
waitHealthy(t, runner, logBytes)

require.NotEmpty(t, runner.instance.stats.StartTime, "start time should be added to latest instance stats after restart")
require.NotEmpty(t, runner.instance.stats.ConnectTime, "connect time should be added to latest instance stats after restart")
require.NotEmpty(t, runner.instances[defaultRegistrationId].stats.StartTime, "start time should be added to latest instance stats after restart")
require.NotEmpty(t, runner.instances[defaultRegistrationId].stats.ConnectTime, "connect time should be added to latest instance stats after restart")

require.NotEmpty(t, previousStats.ExitTime, "exit time should be set on instance stats when restarted")
require.NotEmpty(t, previousStats.Error, "stats instance should have an error on restart")
Expand Down
Loading

0 comments on commit 9ead31a

Please sign in to comment.