Skip to content

Commit

Permalink
Safe rotate upgrade marker file whenever it's written (#3953)
Browse files Browse the repository at this point in the history
To prevent the upgrade marker file at its actual path from ever appearing truncated or partially written, the marker contents are now written to a temporary file first, and that temporary file is then safely rotated into the upgrade marker file's actual path.

---------

Co-authored-by: Pierre HILBERT <pierre.hilbert@elastic.co>
Co-authored-by: Anderson Queiroz <anderson.queiroz@elastic.co>
  • Loading branch information
3 people authored Dec 29, 2023
1 parent 4856600 commit 64bbb2a
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 59 deletions.
21 changes: 19 additions & 2 deletions internal/pkg/agent/application/upgrade/marker_access_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,25 @@ package upgrade
import (
"fmt"
"os"
"path/filepath"
"sync"

"github.com/elastic/elastic-agent-libs/file"
)

func writeMarkerFileCommon(markerFile string, markerBytes []byte, shouldFsync bool) error {
f, err := os.OpenFile(markerFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
f, err := os.CreateTemp(
filepath.Dir(markerFile), fmt.Sprintf("%d-*.tmp", os.Getpid()))
if err != nil {
return fmt.Errorf("failed to open upgrade marker file for writing: %w", err)
}
defer f.Close()
once := sync.Once{}
closeFile := func() {
once.Do(func() {
f.Close()
})
}
defer closeFile()

if _, err := f.Write(markerBytes); err != nil {
return fmt.Errorf("failed to write upgrade marker file: %w", err)
Expand All @@ -27,6 +38,12 @@ func writeMarkerFileCommon(markerFile string, markerBytes []byte, shouldFsync bo
if err := f.Sync(); err != nil {
return fmt.Errorf("failed to sync upgrade marker file to disk: %w", err)
}
// Close the temporary file before rotating it into place: on Windows a file
// that still has an open handle generally cannot be renamed or replaced.
closeFile()

if err := file.SafeFileRotate(markerFile, f.Name()); err != nil {
return fmt.Errorf("failed to safe rotate upgrade marker file: %w", err)
}

return nil
}

This file was deleted.

123 changes: 123 additions & 0 deletions internal/pkg/agent/application/upgrade/marker_access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,15 @@
package upgrade

import (
"context"
"errors"
"fmt"
"math/rand"
"os"
"path/filepath"
"testing"

"github.com/fsnotify/fsnotify"
"github.com/stretchr/testify/require"
)

Expand All @@ -24,3 +29,121 @@ func TestWriteMarkerFile(t *testing.T) {
require.NoError(t, err)
require.Equal(t, markerBytes, data)
}

// TestWriteMarkerFileWithTruncation writes a long marker file followed by a
// shorter one to the same path and verifies that the second write results in
// a smaller file, while a directory watcher checks that the marker file is
// never observed with a size of zero (i.e. truncated) along the way.
func TestWriteMarkerFileWithTruncation(t *testing.T) {
	tmpDir := t.TempDir()
	testMarkerFile := filepath.Join(tmpDir, markerFilename)

	// Watch marker file for the duration of this test, to ensure
	// it's never empty (truncated).
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	errCh := make(chan error)
	watchFileNotEmpty(t, ctx, testMarkerFile, errCh)

	// watchErr is written only by the collector goroutine below. The test
	// goroutine reads it only after collectorDone has been closed, which
	// orders the read after every write and avoids a data race.
	var watchErr error
	collectorDone := make(chan struct{})
	go func() {
		defer close(collectorDone)
		for {
			select {
			case <-ctx.Done():
				return
			case err := <-errCh:
				watchErr = err
			}
		}
	}()

	// Write a long marker file
	err := writeMarkerFile(testMarkerFile, randomBytes(40), true)
	require.NoError(t, err, "could not write long marker file")

	// Get length of file
	fileInfo, err := os.Stat(testMarkerFile)
	require.NoError(t, err)
	originalSize := fileInfo.Size()

	// Overwrite it with a shorter marker file
	err = writeMarkerFile(testMarkerFile, randomBytes(25), true)
	require.NoError(t, err)

	// Get length of file
	fileInfo, err = os.Stat(testMarkerFile)
	require.NoError(t, err)
	newSize := fileInfo.Size()

	// Make sure shorter file is smaller than the original long marker file.
	require.Less(t, newSize, originalSize)

	// Cancel watch on marker file now that we're at the end of the test, wait
	// for the collector goroutine to stop, and check that there were no errors.
	// errCh is deliberately not closed here: this function is the receiver, and
	// closing the channel could make a still-running sender panic.
	cancel()
	<-collectorDone
	require.NoError(t, watchErr)
}

// watchFileNotEmpty starts a goroutine that watches the directory containing
// filePath and reports problems on errCh:
//   - any error delivered by the watcher,
//   - the watcher's Errors or Events channel being closed,
//   - a Create or Write event for filePath after which the file cannot be
//     stat'ed or has a size of zero.
//
// Events for other files in the same directory are ignored. The goroutine
// closes the watcher and returns when ctx is done or when one of the watcher's
// channels is closed. A pending send on errCh is abandoned once ctx is done,
// so the goroutine does not block on a receiver that has already stopped.
//
// Setting up the watcher fails the test immediately via require.
func watchFileNotEmpty(t *testing.T, ctx context.Context, filePath string, errCh chan error) {
	watcher, err := fsnotify.NewWatcher()
	require.NoError(t, err)

	dirPath := filepath.Dir(filePath)
	if err := watcher.Add(dirPath); err != nil {
		// Don't leak the watcher when the test is about to be aborted.
		_ = watcher.Close()
		require.NoError(t, err)
	}

	// report delivers err on errCh unless ctx is done first. It returns false
	// when the context was cancelled, telling the caller to stop watching.
	report := func(err error) bool {
		select {
		case errCh <- err:
			return true
		case <-ctx.Done():
			return false
		}
	}

	// Watch file
	go func() {
		defer watcher.Close()
		for {
			select {
			case <-ctx.Done():
				return
			case err, ok := <-watcher.Errors:
				if !ok { // Channel was closed (i.e. Watcher.Close() was called).
					report(errors.New("fsnotify.Watcher's error channel was closed"))
					return
				}

				if !report(fmt.Errorf("upgrade marker watch returned error: %w", err)) {
					return
				}
			case e, ok := <-watcher.Events:
				if !ok { // Channel was closed (i.e. Watcher.Close() was called).
					report(errors.New("fsnotify.Watcher's events channel was closed"))
					return
				}

				if e.Name != filePath {
					// Since we are watching the directory that will contain the file, we
					// could receive events here for changes to files other than the file we're
					// interested in. We ignore such events.
					continue
				}

				if e.Op&(fsnotify.Create|fsnotify.Write) == 0 {
					// Only creations and writes can change the file's size.
					continue
				}

				// File was created or updated; read its length
				// and send error if it's zero
				fileInfo, err := os.Stat(filePath)
				if err != nil {
					if !report(fmt.Errorf("failed to stat file [%s]: %w", filePath, err)) {
						return
					}
					// fileInfo is nil when Stat fails; skip the size check.
					continue
				}

				if fileInfo.Size() == 0 {
					if !report(fmt.Errorf("file [%s] has size 0", filePath)) {
						return
					}
				}
			}
		}
	}()
}

// randomBytes returns a slice of length pseudo-random bytes drawn from a fixed
// alphabet of letters, digits and punctuation. It returns nil when length is
// zero or negative.
//
// Each chosen rune is converted with byte(), so the non-ASCII letters in the
// alphabet end up as single bytes and the result is not guaranteed to be valid
// UTF-8. It is meant as arbitrary test payload, not as text.
func randomBytes(length int) []byte {
	if length <= 0 {
		return nil
	}

	chars := []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZÅÄÖ" +
		"abcdefghijklmnopqrstuvwxyzåäö" +
		"0123456789" +
		"~=+%^*/()[]{}/!@#$?|")

	// The final size is known up front, so allocate once and fill in place.
	b := make([]byte, length)
	for i := range b {
		b[i] = byte(chars[rand.Intn(len(chars))])
	}

	return b
}
22 changes: 21 additions & 1 deletion internal/pkg/agent/application/upgrade/marker_access_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/cenkalti/backoff/v4"
)

// TODO: is there an upper limit for this timeout?
const markerAccessTimeout = 10 * time.Second
const markerAccessBackoffInitialInterval = 50 * time.Millisecond
const minMarkerAccessRetries = 5
Expand Down Expand Up @@ -71,6 +72,25 @@ func accessMarkerFileWithRetries(accessFn func() error) error {
defer cancel()

expBackoffWithTimeout := backoff.WithContext(expBackoff, ctx)
start := time.Now()

return backoff.Retry(accessFn, expBackoffWithTimeout)
var duration time.Duration
var count int
var err error
if err = accessFn(); err == nil {
return nil
}

for duration = expBackoffWithTimeout.NextBackOff(); duration != backoff.Stop; duration = expBackoffWithTimeout.NextBackOff() {
time.Sleep(duration)

if err = accessFn(); err == nil {
return nil
}

count++
}

return fmt.Errorf("could not write narker after %s and %d retries. Last error: %w",
time.Since(start), count, err)
}

0 comments on commit 64bbb2a

Please sign in to comment.