diff --git a/internal/pkg/agent/application/upgrade/marker_access_common.go b/internal/pkg/agent/application/upgrade/marker_access_common.go
index d195577fa2d..aa7e1ccbf85 100644
--- a/internal/pkg/agent/application/upgrade/marker_access_common.go
+++ b/internal/pkg/agent/application/upgrade/marker_access_common.go
@@ -7,14 +7,46 @@ package upgrade
 import (
 	"fmt"
 	"os"
+	"path/filepath"
+	"sync"
+
+	"github.com/elastic/elastic-agent-libs/file"
 )
 
+// writeMarkerFileCommon writes markerBytes to a temporary file created in
+// markerFile's directory and then rotates that file onto markerFile with
+// file.SafeFileRotate, instead of truncating markerFile in place. The
+// temporary file is removed if it does not get rotated into place.
 func writeMarkerFileCommon(markerFile string, markerBytes []byte, shouldFsync bool) error {
-	f, err := os.OpenFile(markerFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
+	f, err := os.CreateTemp(
+		filepath.Dir(markerFile), fmt.Sprintf("%d-*.tmp", os.Getpid()))
 	if err != nil {
 		return fmt.Errorf("failed to open upgrade marker file for writing: %w", err)
 	}
-	defer f.Close()
+	tmpName := f.Name()
+
+	// Close the file at most once: explicitly before the rotate, or via the
+	// deferred cleanup on an early return.
+	var (
+		once     sync.Once
+		closeErr error
+	)
+	closeFile := func() error {
+		once.Do(func() {
+			closeErr = f.Close()
+		})
+		return closeErr
+	}
+
+	// Until the rotate succeeds the temporary file is garbage; close and
+	// remove it so failed writes don't leave stray *.tmp files behind.
+	rotated := false
+	defer func() {
+		_ = closeFile()
+		if !rotated {
+			_ = os.Remove(tmpName)
+		}
+	}()
 
 	if _, err := f.Write(markerBytes); err != nil {
 		return fmt.Errorf("failed to write upgrade marker file: %w", err)
@@ -27,6 +59,21 @@ func writeMarkerFileCommon(markerFile string, markerBytes []byte, shouldFsync bo
 	if err := f.Sync(); err != nil {
 		return fmt.Errorf("failed to sync upgrade marker file to disk: %w", err)
 	}
+
+	// NOTE(review): the lines between the write and the sync are outside this
+	// hunk; if they return early when shouldFsync is false, the rotate below
+	// never runs and the marker is not updated. Confirm.
+
+	// Close before swapping the files; this is believed to be required on
+	// Windows.
+	if err := closeFile(); err != nil {
+		return fmt.Errorf("failed to close temporary upgrade marker file: %w", err)
+	}
+
+	if err := file.SafeFileRotate(markerFile, tmpName); err != nil {
+		return fmt.Errorf("failed to safe rotate upgrade marker file: %w", err)
+	}
+	rotated = true
 
 	return nil
 }
diff --git a/internal/pkg/agent/application/upgrade/marker_access_common_test.go b/internal/pkg/agent/application/upgrade/marker_access_common_test.go
deleted file mode 100644
index 5a599071794..00000000000
--- a/internal/pkg/agent/application/upgrade/marker_access_common_test.go
+++ /dev/null
@@ -1,56 +0,0 @@
-// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
-// or more contributor license agreements. Licensed under the Elastic License;
-// you may not use this file except in compliance with the Elastic License.
-
-package upgrade
-
-import (
-	"math/rand"
-	"os"
-	"path/filepath"
-	"testing"
-
-	"github.com/stretchr/testify/require"
-)
-
-func TestWriteMarkerFileWithTruncation(t *testing.T) {
-	tmpDir := t.TempDir()
-	testMarkerFile := filepath.Join(tmpDir, markerFilename)
-
-	// Write a long marker file
-	err := writeMarkerFileCommon(testMarkerFile, randomBytes(40), true)
-	require.NoError(t, err)
-
-	// Get length of file
-	fileInfo, err := os.Stat(testMarkerFile)
-	require.NoError(t, err)
-	originalSize := fileInfo.Size()
-
-	// Write a shorter marker file
-	err = writeMarkerFileCommon(testMarkerFile, randomBytes(25), true)
-	require.NoError(t, err)
-
-	// Get length of file
-	fileInfo, err = os.Stat(testMarkerFile)
-	require.NoError(t, err)
-	newSize := fileInfo.Size()
-
-	// Make sure shorter file has is smaller in length than
-	// the original long marker file
-	require.Less(t, newSize, originalSize)
-}
-
-func randomBytes(length int) []byte {
-	chars := []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZÅÄÖ" +
-		"abcdefghijklmnopqrstuvwxyzåäö" +
-		"0123456789" +
-		"~=+%^*/()[]{}/!@#$?|")
-
-	var b []byte
-	for i := 0; i < length; i++ {
-		rune := chars[rand.Intn(len(chars))]
-		b = append(b, byte(rune))
-	}
-
-	return b
-}
diff --git a/internal/pkg/agent/application/upgrade/marker_access_test.go b/internal/pkg/agent/application/upgrade/marker_access_test.go
index 3f1ff637eaa..6da55917c7a 100644
--- a/internal/pkg/agent/application/upgrade/marker_access_test.go
+++ b/internal/pkg/agent/application/upgrade/marker_access_test.go
@@ -5,10 +5,15 @@
 package upgrade
 
 import (
+	"context"
+	"errors"
+	"fmt"
+	"math/rand"
 	"os"
 	"path/filepath"
 	"testing"
 
+	"github.com/fsnotify/fsnotify"
 	"github.com/stretchr/testify/require"
 )
 
@@ -24,3 +29,151 @@ func TestWriteMarkerFile(t *testing.T) {
 	require.NoError(t, err)
 	require.Equal(t, markerBytes, data)
 }
+
+// TestWriteMarkerFileWithTruncation overwrites a long marker with a shorter
+// one and checks that the file shrinks, while a directory watch verifies the
+// marker is never seen with size 0 along the way.
+func TestWriteMarkerFileWithTruncation(t *testing.T) {
+	tmpDir := t.TempDir()
+	testMarkerFile := filepath.Join(tmpDir, markerFilename)
+
+	// Watch marker file for the duration of this test, to ensure
+	// it's never empty (truncated).
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	errCh := make(chan error)
+	watchFileNotEmpty(t, ctx, testMarkerFile, errCh)
+
+	// Collect watch errors in the background. collectorDone is closed when
+	// the collector exits, so reading watchErr after waiting on it is not a
+	// data race.
+	var watchErr error
+	collectorDone := make(chan struct{})
+	go func() {
+		defer close(collectorDone)
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			case err := <-errCh:
+				watchErr = err
+			}
+		}
+	}()
+
+	// Write a long marker file
+	err := writeMarkerFile(testMarkerFile, randomBytes(40), true)
+	require.NoError(t, err, "could not write long marker file")
+
+	// Get length of file
+	fileInfo, err := os.Stat(testMarkerFile)
+	require.NoError(t, err)
+	originalSize := fileInfo.Size()
+
+	// Write a shorter marker file
+	err = writeMarkerFile(testMarkerFile, randomBytes(25), true)
+	require.NoError(t, err)
+
+	// Get length of file
+	fileInfo, err = os.Stat(testMarkerFile)
+	require.NoError(t, err)
+	newSize := fileInfo.Size()
+
+	// Make sure shorter file is smaller than the original long marker file.
+	require.Less(t, newSize, originalSize)
+
+	// Stop the watch now that we're at the end of the test, wait for the
+	// collector to finish, and check that there were no errors. errCh is
+	// deliberately not closed: the watcher goroutine is its sender and may
+	// still be running.
+	cancel()
+	<-collectorDone
+	require.NoError(t, watchErr)
+}
+
+// watchFileNotEmpty watches the directory containing filePath until ctx is
+// done and sends an error on errCh whenever filePath is created or written
+// and is found to have size 0, or when the watch itself fails.
+func watchFileNotEmpty(t *testing.T, ctx context.Context, filePath string, errCh chan error) {
+	watcher, err := fsnotify.NewWatcher()
+	require.NoError(t, err)
+
+	dirPath := filepath.Dir(filePath)
+	err = watcher.Add(dirPath)
+	require.NoError(t, err)
+
+	// report delivers err to errCh but gives up once ctx is done, so this
+	// goroutine cannot block forever after the receiver has gone away.
+	report := func(err error) {
+		select {
+		case errCh <- err:
+		case <-ctx.Done():
+		}
+	}
+
+	// Watch file
+	go func() {
+		defer watcher.Close()
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			case err, ok := <-watcher.Errors:
+				if !ok { // Channel was closed (i.e. Watcher.Close() was called).
+					report(errors.New("fsnotify.Watcher's error channel was closed"))
+					return
+				}
+
+				report(fmt.Errorf("upgrade marker watch returned error: %w", err))
+			case e, ok := <-watcher.Events:
+				if !ok { // Channel was closed (i.e. Watcher.Close() was called).
+					report(errors.New("fsnotify.Watcher's events channel was closed"))
+					return
+				}
+
+				if e.Name != filePath {
+					// Since we are watching the directory that will contain the file, we
+					// could receive events here for changes to files other than the file we're
+					// interested in. We ignore such events.
+					continue
+				}
+
+				if e.Op&(fsnotify.Create|fsnotify.Write) == 0 {
+					continue
+				}
+
+				// File was created or updated; read its length
+				// and send error if it's zero
+				fileInfo, err := os.Stat(filePath)
+				if err != nil {
+					report(fmt.Errorf("failed to stat file [%s]: %w", filePath, err))
+					// fileInfo is not usable when Stat fails; skip the size check.
+					continue
+				}
+
+				if fileInfo.Size() == 0 {
+					report(fmt.Errorf("file [%s] has size 0", filePath))
+				}
+			}
+		}
+	}()
+}
+
+// randomBytes returns length bytes, each picked at random from a fixed set
+// of characters. Non-ASCII characters in the set are truncated to their low
+// byte, so the result is not guaranteed to be valid UTF-8.
+func randomBytes(length int) []byte {
+	chars := []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZÅÄÖ" +
+		"abcdefghijklmnopqrstuvwxyzåäö" +
+		"0123456789" +
+		"~=+%^*/()[]{}/!@#$?|")
+
+	var b []byte
+	for i := 0; i < length; i++ {
+		// Named r rather than rune to avoid shadowing the builtin type.
+		r := chars[rand.Intn(len(chars))]
+		b = append(b, byte(r))
+	}
+
+	return b
+}
diff --git a/internal/pkg/agent/application/upgrade/marker_access_windows.go b/internal/pkg/agent/application/upgrade/marker_access_windows.go
index cb37f9c0e88..15dfc61adc1 100644
--- a/internal/pkg/agent/application/upgrade/marker_access_windows.go
+++ b/internal/pkg/agent/application/upgrade/marker_access_windows.go
@@ -15,6 +15,7 @@ import (
 	"github.com/cenkalti/backoff/v4"
 )
 
+// TODO: is there an upper limit for this timeout?
 const markerAccessTimeout = 10 * time.Second
 const markerAccessBackoffInitialInterval = 50 * time.Millisecond
 const minMarkerAccessRetries = 5
@@ -71,6 +72,27 @@ func accessMarkerFileWithRetries(accessFn func() error) error {
 	defer cancel()
 
 	expBackoffWithTimeout := backoff.WithContext(expBackoff, ctx)
+	start := time.Now()
 
-	return backoff.Retry(accessFn, expBackoffWithTimeout)
+	// Retry by hand so the returned error can report the elapsed time, the
+	// retry count and the last error from accessFn.
+	var duration time.Duration
+	var count int
+	var err error
+	if err = accessFn(); err == nil {
+		return nil
+	}
+
+	for duration = expBackoffWithTimeout.NextBackOff(); duration != backoff.Stop; duration = expBackoffWithTimeout.NextBackOff() {
+		time.Sleep(duration)
+
+		if err = accessFn(); err == nil {
+			return nil
+		}
+
+		count++
+	}
+
+	return fmt.Errorf("could not access marker after %s and %d retries; last error: %w",
+		time.Since(start), count, err)
 }