diff --git a/cmd/launcher/internal/updater/updater.go b/cmd/launcher/internal/updater/updater.go index 6e293660d..660816bac 100644 --- a/cmd/launcher/internal/updater/updater.go +++ b/cmd/launcher/internal/updater/updater.go @@ -62,9 +62,12 @@ func NewUpdater( return nil, err } + ctx, cancel := context.WithCancel(ctx) + updateCmd := &updaterCmd{ updater: updater, ctx: ctx, + cancel: cancel, stopChan: make(chan bool), config: config, runUpdaterRetryInterval: 30 * time.Minute, @@ -84,6 +87,7 @@ type updater interface { type updaterCmd struct { updater updater ctx context.Context + cancel context.CancelFunc stopChan chan bool stopExecution func() config *UpdaterConfig @@ -155,4 +159,6 @@ func (u *updaterCmd) interrupt(err error) { if u.stopExecution != nil { u.stopExecution() } + + u.cancel() } diff --git a/cmd/launcher/internal/updater/updater_test.go b/cmd/launcher/internal/updater/updater_test.go index d9235f20d..8ba35f1f0 100644 --- a/cmd/launcher/internal/updater/updater_test.go +++ b/cmd/launcher/internal/updater/updater_test.go @@ -114,6 +114,7 @@ func Test_updaterCmd_execute(t *testing.T) { u := &updaterCmd{ updater: tt.fields.updater, ctx: ctx, + cancel: cancelCtx, stopChan: tt.fields.stopChan, config: tt.fields.config, runUpdaterRetryInterval: tt.fields.runUpdaterRetryInterval, @@ -194,9 +195,13 @@ func Test_updaterCmd_interrupt(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + u := &updaterCmd{ stopChan: tt.fields.stopChan, config: tt.fields.config, + ctx: ctx, + cancel: cancel, } // using this wait group to ensure that something gets received on u.StopChan diff --git a/pkg/autoupdate/tuf/autoupdate.go b/pkg/autoupdate/tuf/autoupdate.go index 0aa30169d..6c1077dff 100644 --- a/pkg/autoupdate/tuf/autoupdate.go +++ b/pkg/autoupdate/tuf/autoupdate.go @@ -87,7 +87,7 @@ func NewTufAutoupdater(k types.Knapsack, metadataHttpClient *http.Client, mirror checkInterval: k.AutoupdateInterval(), store: k.AutoupdateErrorsStore(), osquerier: osquerier, - osquerierRetryInterval: 1 * time.Minute, + osquerierRetryInterval: 30 * time.Second, logger: log.NewNopLogger(), } @@ -184,10 +184,10 @@ func (ta *TufAutoupdater) Execute() (err error) { } func (ta *TufAutoupdater) Interrupt(_ error) { - ta.interrupt <- struct{}{} if err := ta.libraryManager.Close(); err != nil { level.Debug(ta.logger).Log("msg", "could not close library on interrupt", "err", err) } + ta.interrupt <- struct{}{} } // tidyLibrary gets the current running version for each binary (so that the current version is not removed) @@ -306,6 +306,10 @@ func (ta *TufAutoupdater) downloadUpdate(binary autoupdatableBinary, targets dat return "", fmt.Errorf("could not find release: %w", err) } + if ta.libraryManager.Available(binary, release) { + return "", nil + } + // Get the current running version if available -- don't error out if we can't // get it, since the worst case is that we download an update whose version matches // our install version. @@ -315,10 +319,6 @@ func (ta *TufAutoupdater) downloadUpdate(binary autoupdatableBinary, targets dat return "", nil } - if ta.libraryManager.Available(binary, release) { - return "", nil - } - if err := ta.libraryManager.AddToLibrary(binary, currentVersion, release, releaseMetadata); err != nil { return "", fmt.Errorf("could not add release %s for binary %s to library: %w", release, binary, err) } diff --git a/pkg/autoupdate/tuf/library_manager.go b/pkg/autoupdate/tuf/library_manager.go index 2ded06a4e..457cdd286 100644 --- a/pkg/autoupdate/tuf/library_manager.go +++ b/pkg/autoupdate/tuf/library_manager.go @@ -69,6 +69,12 @@ func newUpdateLibraryManager(mirrorUrl string, mirrorClient *http.Client, baseDi // Close cleans up the temporary staging directory func (ulm *updateLibraryManager) Close() error { + // Acquire lock to ensure we aren't interrupting an ongoing operation + for _, binary := range binaries { + ulm.lock.Lock(binary) + defer ulm.lock.Unlock(binary) + } + if err := os.RemoveAll(ulm.stagingDir); err != nil { return fmt.Errorf("could not remove staging dir %s: %w", ulm.stagingDir, err) } diff --git a/pkg/sendbuffer/sendbuffer.go b/pkg/sendbuffer/sendbuffer.go index 5b2fcc47f..09cec52ca 100644 --- a/pkg/sendbuffer/sendbuffer.go +++ b/pkg/sendbuffer/sendbuffer.go @@ -141,7 +141,7 @@ func (sb *SendBuffer) Run(ctx context.Context) error { case <-ticker.C: continue case <-ctx.Done(): - break + return nil } } } diff --git a/pkg/traces/exporter/exporter.go b/pkg/traces/exporter/exporter.go index 3e976f537..d55b5c66c 100644 --- a/pkg/traces/exporter/exporter.go +++ b/pkg/traces/exporter/exporter.go @@ -56,6 +56,7 @@ type TraceExporter struct { disableIngestTLS bool enabled bool traceSamplingRate float64 + interrupt chan struct{} } // NewTraceExporter sets up our traces to be exported via OTLP over HTTP. @@ -86,6 +87,7 @@ func NewTraceExporter(ctx context.Context, k types.Knapsack, client osquery.Quer disableIngestTLS: k.DisableTraceIngestTLS(), enabled: k.ExportTraces(), traceSamplingRate: k.TraceSamplingRate(), + interrupt: make(chan struct{}), } // Observe ExportTraces and IngestServerURL changes to know when to start/stop exporting, and where @@ -248,8 +250,7 @@ func (t *TraceExporter) setNewGlobalProvider() { // Execute is a no-op -- the exporter is already running in the background. The TraceExporter // otherwise only responds to control server events. func (t *TraceExporter) Execute() error { - // Does nothing, just waiting for launcher to exit - <-context.Background().Done() + <-t.interrupt return nil } @@ -257,6 +258,8 @@ func (t *TraceExporter) Interrupt(_ error) { if t.provider != nil { t.provider.Shutdown(context.Background()) } + + t.interrupt <- struct{}{} } // Update satisfies control.subscriber interface -- looks at changes to the `observability_ingest` subsystem,