diff --git a/lib/autoupdate/agent/installer.go b/lib/autoupdate/agent/installer.go index 4da41e8e55509..96e72c0a5cfa3 100644 --- a/lib/autoupdate/agent/installer.go +++ b/lib/autoupdate/agent/installer.go @@ -31,6 +31,7 @@ import ( "os" "path/filepath" "runtime" + "syscall" "text/template" "time" @@ -50,13 +51,13 @@ var ( // See utils.Extract for more details on how this list is parsed. // Paths must use tarball-style / separators (not filepath). tgzExtractPaths = []utils.ExtractPath{ - {Src: "teleport/examples/systemd/teleport.service", Dst: "etc/systemd/teleport.service"}, - {Src: "teleport/examples", Skip: true}, - {Src: "teleport/install", Skip: true}, - {Src: "teleport/README.md", Dst: "share/README.md"}, - {Src: "teleport/CHANGELOG.md", Dst: "share/CHANGELOG.md"}, - {Src: "teleport/VERSION", Dst: "share/VERSION"}, - {Src: "teleport", Dst: "bin"}, + {Src: "teleport/examples/systemd/teleport.service", Dst: "etc/systemd/teleport.service", DirMode: 0755}, + {Src: "teleport/examples", Skip: true, DirMode: 0755}, + {Src: "teleport/install", Skip: true, DirMode: 0755}, + {Src: "teleport/README.md", Dst: "share/README.md", DirMode: 0755}, + {Src: "teleport/CHANGELOG.md", Dst: "share/CHANGELOG.md", DirMode: 0755}, + {Src: "teleport/VERSION", Dst: "share/VERSION", DirMode: 0755}, + {Src: "teleport", Dst: "bin", DirMode: 0755}, } // servicePath contains the path to the Teleport SystemD service within the version directory. @@ -82,11 +83,9 @@ type LocalInstaller struct { ReservedFreeInstallDisk uint64 } -// ErrLinked is returned when a linked version cannot be removed. -var ErrLinked = errors.New("linked version cannot be removed") - // Remove a Teleport version directory from InstallDir. // This function is idempotent. +// See Installer interface for additional specs. func (li *LocalInstaller) Remove(ctx context.Context, version string) error { // os.RemoveAll is dangerous because it can remove an entire directory tree. // We must validate the version to ensure that we remove only a single path @@ -102,7 +101,7 @@ func (li *LocalInstaller) Remove(ctx context.Context, version string) error { return trace.Errorf("failed to determine if linked: %w", err) } if linked { - return trace.Wrap(ErrLinked) + return trace.Errorf("refusing to remove: %w", ErrLinked) } // invalidate checksum first, to protect against partially-removed @@ -119,7 +118,8 @@ func (li *LocalInstaller) Remove(ctx context.Context, version string) error { // Install a Teleport version directory in InstallDir. // This function is idempotent. -func (li *LocalInstaller) Install(ctx context.Context, version, template string, flags InstallFlags) error { +// See Installer interface for additional specs. +func (li *LocalInstaller) Install(ctx context.Context, version, template string, flags InstallFlags) (err error) { versionDir, err := li.versionDir(version) if err != nil { return trace.Wrap(err) @@ -175,11 +175,17 @@ func (li *LocalInstaller) Install(ctx context.Context, version, template string, if err != nil { return trace.Errorf("failed to download teleport: %w", err) } - // Seek to the start of the tgz file after writing if _, err := f.Seek(0, io.SeekStart); err != nil { return trace.Errorf("failed seek to start of download: %w", err) } + + // If interrupted, close the file immediately to stop extracting. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + context.AfterFunc(ctx, func() { + _ = f.Close() // safe to close file multiple times + }) // Check integrity before decompression if !bytes.Equal(newSum, pathSum) { return trace.Errorf("mismatched checksum, download possibly corrupt") @@ -193,6 +199,17 @@ func (li *LocalInstaller) Install(ctx context.Context, version, template string, if _, err := f.Seek(0, io.SeekStart); err != nil { return trace.Errorf("failed seek to start: %w", err) } + + // If there's an error after we start extracting, delete the version dir. + defer func() { + if err != nil { + if err := os.RemoveAll(versionDir); err != nil { + li.Log.WarnContext(ctx, "Failed to cleanup broken version extraction.", "error", err, "dir", versionDir) + } + } + }() + + // Extract tgz into version directory. if err := li.extract(ctx, versionDir, f, n); err != nil { return trace.Errorf("failed to extract teleport: %w", err) } @@ -374,51 +391,118 @@ func (li *LocalInstaller) List(ctx context.Context) (versions []string, err erro return versions, nil } -// Link the specified version into the system LinkBinDir. -func (li *LocalInstaller) Link(ctx context.Context, version string) error { +// Link the specified version into the system LinkBinDir and LinkServiceDir. +// The revert function restores the previous linking. +// See Installer interface for additional specs. +func (li *LocalInstaller) Link(ctx context.Context, version string) (revert func(context.Context) bool, err error) { + // setup revert function + type symlink struct { + old, new string + } + var revertLinks []symlink + revert = func(ctx context.Context) bool { + // This function is safe to call repeatedly. + // Returns true only when all symlinks are successfully reverted. + var keep []symlink + for _, l := range revertLinks { + err := renameio.Symlink(l.old, l.new) + if err != nil { + keep = append(keep, l) + li.Log.ErrorContext(ctx, "Failed to revert symlink", "old", l.old, "new", l.new, "err", err) + } + } + revertLinks = keep + return len(revertLinks) == 0 + } + // revert immediately on error, so caller can ignore revert arg + defer func() { + if err != nil { + revert(ctx) + } + }() + versionDir, err := li.versionDir(version) if err != nil { - return trace.Wrap(err) + return revert, trace.Wrap(err) } // ensure target directories exist before trying to create links err = os.MkdirAll(li.LinkBinDir, 0755) if err != nil { - return trace.Wrap(err) + return revert, trace.Wrap(err) } err = os.MkdirAll(li.LinkServiceDir, 0755) if err != nil { - return trace.Wrap(err) + return revert, trace.Wrap(err) } // create binary links + binDir := filepath.Join(versionDir, "bin") entries, err := os.ReadDir(binDir) if err != nil { - return trace.Errorf("failed to find Teleport binary directory: %w", err) + return revert, trace.Errorf("failed to find Teleport binary directory: %w", err) } var linked int for _, entry := range entries { if entry.IsDir() { continue } - err := renameio.Symlink(filepath.Join(binDir, entry.Name()), filepath.Join(li.LinkBinDir, entry.Name())) + oldname := filepath.Join(binDir, entry.Name()) + newname := filepath.Join(li.LinkBinDir, entry.Name()) + orig, err := tryLink(oldname, newname) if err != nil { - return trace.Wrap(err) + return revert, trace.Errorf("failed to create symlink for %s: %w", filepath.Base(oldname), err) + } + if orig != "" { + revertLinks = append(revertLinks, symlink{ + old: orig, + new: newname, + }) } linked++ } if linked == 0 { - return trace.Errorf("no binaries available to link") + return revert, trace.Errorf("no binaries available to link") } // create systemd service link - service := filepath.Join(versionDir, servicePath) - err = renameio.Symlink(service, filepath.Join(li.LinkServiceDir, filepath.Base(servicePath))) + + oldname := filepath.Join(versionDir, servicePath) + newname := filepath.Join(li.LinkServiceDir, filepath.Base(servicePath)) + orig, err := tryLink(oldname, newname) if err != nil { - return trace.Wrap(err) + return revert, trace.Errorf("failed to create symlink for %s: %w", filepath.Base(oldname), err) } - return nil + if orig != "" { + revertLinks = append(revertLinks, symlink{ + old: orig, + new: newname, + }) + } + return revert, nil +} + +// tryLink attempts to create a symlink, atomically replacing an existing link if already present. +// If a non-symlink file or directory exists in newname already, tryLink errors. +func tryLink(oldname, newname string) (orig string, err error) { + orig, err = os.Readlink(newname) + if errors.Is(err, os.ErrInvalid) || + errors.Is(err, syscall.EINVAL) { // workaround missing ErrInvalid wrapper + // important: do not attempt to replace a non-linked install of Teleport + return orig, trace.Errorf("refusing to replace file at %s", newname) + } + if err != nil && !errors.Is(err, os.ErrNotExist) { + return orig, trace.Wrap(err) + } + if orig == oldname { + return "", nil + } + err = renameio.Symlink(oldname, newname) + if err != nil { + return orig, trace.Wrap(err) + } + return orig, nil } // versionDir returns the storage directory for a Teleport version. diff --git a/lib/autoupdate/agent/installer_test.go b/lib/autoupdate/agent/installer_test.go index 2602704208855..d4f58f782dc62 100644 --- a/lib/autoupdate/agent/installer_test.go +++ b/lib/autoupdate/agent/installer_test.go @@ -196,16 +196,18 @@ func TestLocalInstaller_Link(t *testing.T) { const version = "new-version" tests := []struct { - name string - dirs []string - files []string + name string + installDirs []string + installFiles []string + existingLinks []string + existingFiles []string - links []string - errMatch string + resultLinks []string + errMatch string }{ { - name: "present", - dirs: []string{ + name: "present with new links", + installDirs: []string{ "bin", "bin/somedir", "etc", @@ -213,7 +215,7 @@ func TestLocalInstaller_Link(t *testing.T) { "etc/systemd/somedir", "somedir", }, - files: []string{ + installFiles: []string{ "bin/teleport", "bin/tsh", "bin/tbot", @@ -221,7 +223,7 @@ func TestLocalInstaller_Link(t *testing.T) { "README", }, - links: []string{ + resultLinks: []string{ "bin/teleport", "bin/tsh", "bin/tbot", @@ -229,15 +231,102 @@ func TestLocalInstaller_Link(t *testing.T) { }, }, { - name: "no links", - files: []string{"README"}, - dirs: []string{"bin"}, + name: "present with existing links", + installDirs: []string{ + "bin", + "bin/somedir", + "etc", + "etc/systemd", + "etc/systemd/somedir", + "somedir", + }, + installFiles: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + servicePath, + "README", + }, + existingLinks: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + "lib/systemd/system/teleport.service", + }, + + resultLinks: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + "lib/systemd/system/teleport.service", + }, + }, + { + name: "conflicting systemd files", + installDirs: []string{ + "bin", + "bin/somedir", + "etc", + "etc/systemd", + "etc/systemd/somedir", + "somedir", + }, + installFiles: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + servicePath, + "README", + }, + existingLinks: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + }, + existingFiles: []string{ + "lib/systemd/system/teleport.service", + }, + + errMatch: "refusing", + }, + { + name: "conflicting bin files", + installDirs: []string{ + "bin", + "bin/somedir", + "etc", + "etc/systemd", + "etc/systemd/somedir", + "somedir", + }, + installFiles: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + servicePath, + "README", + }, + existingLinks: []string{ + "bin/teleport", + "bin/tbot", + "lib/systemd/system/teleport.service", + }, + existingFiles: []string{ + "bin/tsh", + }, + + errMatch: "refusing", + }, + { + name: "no links", + installFiles: []string{"README"}, + installDirs: []string{"bin"}, errMatch: "no binaries", }, { - name: "no bin directory", - files: []string{"README"}, + name: "no bin directory", + installFiles: []string{"README"}, errMatch: "binary directory", }, @@ -251,16 +340,30 @@ func TestLocalInstaller_Link(t *testing.T) { err := os.MkdirAll(versionDir, 0o755) require.NoError(t, err) - for _, d := range tt.dirs { + // setup files in version directory + for _, d := range tt.installDirs { err := os.Mkdir(filepath.Join(versionDir, d), os.ModePerm) require.NoError(t, err) } - for _, n := range tt.files { + for _, n := range tt.installFiles { err := os.WriteFile(filepath.Join(versionDir, n), []byte(filepath.Base(n)), os.ModePerm) require.NoError(t, err) } + // setup files in system links directory linkDir := t.TempDir() + for _, n := range tt.existingLinks { + err := os.MkdirAll(filepath.Dir(filepath.Join(linkDir, n)), os.ModePerm) + require.NoError(t, err) + err = os.Symlink(filepath.Base(n)+".old", filepath.Join(linkDir, n)) + require.NoError(t, err) + } + for _, n := range tt.existingFiles { + err := os.MkdirAll(filepath.Dir(filepath.Join(linkDir, n)), os.ModePerm) + require.NoError(t, err) + err = os.WriteFile(filepath.Join(linkDir, n), []byte(filepath.Base(n)), os.ModePerm) + require.NoError(t, err) + } installer := &LocalInstaller{ InstallDir: versionsDir, @@ -269,19 +372,50 @@ func TestLocalInstaller_Link(t *testing.T) { Log: slog.Default(), } ctx := context.Background() - err = installer.Link(ctx, version) + revert, err := installer.Link(ctx, version) if tt.errMatch != "" { require.Error(t, err) assert.Contains(t, err.Error(), tt.errMatch) + + // verify automatic revert + for _, link := range tt.existingLinks { + v, err := os.Readlink(filepath.Join(linkDir, link)) + require.NoError(t, err) + require.Equal(t, filepath.Base(link)+".old", v) + } + for _, n := range tt.existingFiles { + v, err := os.ReadFile(filepath.Join(linkDir, n)) + require.NoError(t, err) + require.Equal(t, filepath.Base(n), string(v)) + } + + // ensure revert still succeeds + ok := revert(ctx) + require.True(t, ok) return } require.NoError(t, err) - for _, link := range tt.links { + // verify links + for _, link := range tt.resultLinks { v, err := os.ReadFile(filepath.Join(linkDir, link)) require.NoError(t, err) require.Equal(t, filepath.Base(link), string(v)) } + + // verify manual revert + ok := revert(ctx) + require.True(t, ok) + for _, link := range tt.existingLinks { + v, err := os.Readlink(filepath.Join(linkDir, link)) + require.NoError(t, err) + require.Equal(t, filepath.Base(link)+".old", v) + } + for _, n := range tt.existingFiles { + v, err := os.ReadFile(filepath.Join(linkDir, n)) + require.NoError(t, err) + require.Equal(t, filepath.Base(n), string(v)) + } }) } } @@ -397,7 +531,7 @@ func TestLocalInstaller_Remove(t *testing.T) { ctx := context.Background() if tt.linkedVersion != "" { - err = installer.Link(ctx, tt.linkedVersion) + _, err = installer.Link(ctx, tt.linkedVersion) require.NoError(t, err) } err = installer.Remove(ctx, tt.removeVersion) diff --git a/lib/autoupdate/agent/process.go b/lib/autoupdate/agent/process.go new file mode 100644 index 0000000000000..eba70aa56a690 --- /dev/null +++ b/lib/autoupdate/agent/process.go @@ -0,0 +1,179 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package agent + +import ( + "bytes" + "context" + "errors" + "log/slog" + "os" + "os/exec" + + "github.com/gravitational/trace" +) + +// SystemdService manages a Teleport systemd service. +type SystemdService struct { + // ServiceName specifies the systemd service name. + ServiceName string + // Log contains a logger. + Log *slog.Logger +} + +// Reload a systemd service. +// Attempts a graceful reload before a hard restart. +// See Process interface for more details. +func (s SystemdService) Reload(ctx context.Context) error { + if err := s.checkSystem(ctx); err != nil { + return trace.Wrap(err) + } + // Command error codes < 0 indicate that we are unable to run the command. + // Errors from s.systemctl are logged along with stderr and stdout (debug only). + + // If the service is not running, return ErrNotNeeded. + // Note systemctl reload returns an error if the unit is not active, and + // try-reload-or-restart is too recent of an addition for centos7. + code := s.systemctl(ctx, slog.LevelDebug, "is-active", "--quiet", s.ServiceName) + switch { + case code < 0: + return trace.Errorf("unable to determine if systemd service is active") + case code > 0: + s.Log.WarnContext(ctx, "Teleport systemd service not running.") + return trace.Wrap(ErrNotNeeded) + } + // Attempt graceful reload of running service. + code = s.systemctl(ctx, slog.LevelError, "reload", s.ServiceName) + switch { + case code < 0: + return trace.Errorf("unable to attempt reload of Teleport systemd service") + case code > 0: + // Graceful reload fails, try hard restart. + code = s.systemctl(ctx, slog.LevelError, "try-restart", s.ServiceName) + if code != 0 { + return trace.Errorf("hard restart of Teleport systemd service failed") + } + s.Log.WarnContext(ctx, "Teleport ungracefully restarted. Connections potentially dropped.") + default: + s.Log.InfoContext(ctx, "Teleport gracefully reloaded.") + } + + // TODO(sclevine): Ensure restart was successful and verify healthcheck. + + return nil +} + +// Sync systemd service configuration by running systemctl daemon-reload. +// See Process interface for more details. +func (s SystemdService) Sync(ctx context.Context) error { + if err := s.checkSystem(ctx); err != nil { + return trace.Wrap(err) + } + code := s.systemctl(ctx, slog.LevelError, "daemon-reload") + if code != 0 { + return trace.Errorf("unable to reload systemd configuration") + } + return nil +} + +// checkSystem returns an error if the system is not compatible with this process manager. +func (s SystemdService) checkSystem(ctx context.Context) error { + _, err := os.Stat("/run/systemd/system") + if errors.Is(err, os.ErrNotExist) { + s.Log.ErrorContext(ctx, "This system does not support systemd, which is required by the updater.") + return trace.Wrap(ErrNotSupported) + } + return trace.Wrap(err) +} + +// systemctl returns a systemctl subcommand, converting the output to logs. +// Output sent to stdout is logged at debug level. +// Output sent to stderr is logged at the level specified by errLevel. +func (s SystemdService) systemctl(ctx context.Context, errLevel slog.Level, args ...string) int { + cmd := exec.CommandContext(ctx, "systemctl", args...) + stderr := &lineLogger{ctx: ctx, log: s.Log, level: errLevel} + stdout := &lineLogger{ctx: ctx, log: s.Log, level: slog.LevelDebug} + cmd.Stderr = stderr + cmd.Stdout = stdout + err := cmd.Run() + stderr.Flush() + stdout.Flush() + code := cmd.ProcessState.ExitCode() + + // Treat out-of-range exit code (255) as an error executing the command. + // This allows callers to treat codes that are more likely OS-related as execution errors + // instead of intentionally returned error codes. + if code == 255 { + code = -1 + } + if err != nil { + s.Log.Log(ctx, errLevel, "Failed to run systemctl.", + "args", args, + "code", code, + "error", err) + } + return code +} + +// lineLogger logs each line written to it. +type lineLogger struct { + ctx context.Context + log *slog.Logger + level slog.Level + + last bytes.Buffer +} + +func (w *lineLogger) Write(p []byte) (n int, err error) { + lines := bytes.Split(p, []byte("\n")) + // Finish writing line + if len(lines) > 0 { + n, err = w.last.Write(lines[0]) + lines = lines[1:] + } + // Quit if no newline + if len(lines) == 0 || err != nil { + return n, trace.Wrap(err) + } + + // Newline found, log line + w.log.Log(w.ctx, w.level, w.last.String()) //nolint:sloglint // msg cannot be constant + n += 1 + w.last.Reset() + + // Log lines that are already newline-terminated + for _, line := range lines[:len(lines)-1] { + w.log.Log(w.ctx, w.level, string(line)) //nolint:sloglint // msg cannot be constant + n += len(line) + 1 + } + + // Store remaining line non-newline-terminated line. + n2, err := w.last.Write(lines[len(lines)-1]) + n += n2 + return n, trace.Wrap(err) +} + +// Flush logs any trailing bytes that were never terminated with a newline. +func (w *lineLogger) Flush() { + if w.last.Len() == 0 { + return + } + w.log.Log(w.ctx, w.level, w.last.String()) //nolint:sloglint // msg cannot be constant + w.last.Reset() +} diff --git a/lib/autoupdate/agent/process_test.go b/lib/autoupdate/agent/process_test.go new file mode 100644 index 0000000000000..5ffa70dd0091e --- /dev/null +++ b/lib/autoupdate/agent/process_test.go @@ -0,0 +1,71 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package agent + +import ( + "bytes" + "context" + "log/slog" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestLineLogger(t *testing.T) { + t.Parallel() + + out := &bytes.Buffer{} + ll := lineLogger{ + ctx: context.Background(), + log: slog.New(slog.NewTextHandler(out, + &slog.HandlerOptions{ReplaceAttr: msgOnly}, + )), + } + + for _, e := range []struct { + v string + n int + }{ + {v: "", n: 0}, + {v: "a", n: 1}, + {v: "b\n", n: 2}, + {v: "c\nd", n: 3}, + {v: "e\nf\ng", n: 5}, + {v: "h", n: 1}, + {v: "", n: 0}, + {v: "\n", n: 1}, + {v: "i\n", n: 2}, + {v: "j", n: 1}, + } { + n, err := ll.Write([]byte(e.v)) + require.NoError(t, err) + require.Equal(t, e.n, n) + } + require.Equal(t, "msg=ab\nmsg=c\nmsg=de\nmsg=f\nmsg=gh\nmsg=i\n", out.String()) + ll.Flush() + require.Equal(t, "msg=ab\nmsg=c\nmsg=de\nmsg=f\nmsg=gh\nmsg=i\nmsg=j\n", out.String()) +} + +func msgOnly(_ []string, a slog.Attr) slog.Attr { + switch a.Key { + case "time", "level": + return slog.Attr{} + } + return slog.Attr{Key: a.Key, Value: a.Value} +} diff --git a/lib/autoupdate/agent/updater.go b/lib/autoupdate/agent/updater.go index 7071f16e42d15..b82c3c6d419cb 100644 --- a/lib/autoupdate/agent/updater.go +++ b/lib/autoupdate/agent/updater.go @@ -135,6 +135,10 @@ func NewLocalUpdater(cfg LocalUpdaterConfig) (*Updater, error) { ReservedFreeTmpDisk: reservedFreeDisk, ReservedFreeInstallDisk: reservedFreeDisk, }, + Process: &SystemdService{ + ServiceName: "teleport.service", + Log: cfg.Log, + }, }, nil } @@ -166,26 +170,58 @@ type Updater struct { ConfigPath string // Installer manages installations of the Teleport agent. Installer Installer + // Process manages a running instance of Teleport. + Process Process } // Installer provides an API for installing Teleport agents. type Installer interface { // Install the Teleport agent at version from the download template. - // This function must be idempotent. + // Install must be idempotent. Install(ctx context.Context, version, template string, flags InstallFlags) error - // Link the Teleport agent at version into the system location. - // This function must be idempotent. - Link(ctx context.Context, version string) error + // Link the Teleport agent at the specified version into the system location. + // The revert function must restore the previous linking, returning false on any failure. + // Link must be idempotent. + // Link's revert function must be idempotent. + Link(ctx context.Context, version string) (revert func(context.Context) bool, err error) // List the installed versions of Teleport. List(ctx context.Context) (versions []string, err error) // Remove the Teleport agent at version. - // This function must be idempotent. + // Must return ErrLinked if unable to remove due to being linked. + // Remove must be idempotent. Remove(ctx context.Context, version string) error } +var ( + // ErrLinked is returned when a linked version cannot be operated on. + ErrLinked = errors.New("version is linked") + // ErrNotNeeded is returned when the operation is not needed. + ErrNotNeeded = errors.New("not needed") + // ErrNotSupported is returned when the operation is not supported on the platform. + ErrNotSupported = errors.New("not supported on this platform") +) + +// Process provides an API for interacting with a running Teleport process. +type Process interface { + // Reload must reload the Teleport process as gracefully as possible. + // If the process is not healthy after reloading, Reload must return an error. + // If the process did not require reloading, Reload must return ErrNotNeeded. + // E.g., if the process is not enabled, or it was already reloaded after the last Sync. + // If the type implementing Process does not support the system process manager, + // Reload must return ErrNotSupported. + Reload(ctx context.Context) error + // Sync must validate and synchronize process configuration. + // After the linked Teleport installation is changed, failure to call Sync without + // error before Reload may result in undefined behavior. + // If the type implementing Process does not support the system process manager, + // Sync must return ErrNotSupported. + Sync(ctx context.Context) error +} + // InstallFlags sets flags for the Teleport installation type InstallFlags int +// TODO(sclevine): add flags for need_restart and selinux config const ( // FlagEnterprise installs enterprise Teleport FlagEnterprise InstallFlags = 1 << iota @@ -215,30 +251,20 @@ type OverrideConfig struct { // This function is idempotent. func (u *Updater) Enable(ctx context.Context, override OverrideConfig) error { // Read configuration from update.yaml and override any new values passed as flags. - cfg, err := u.readConfig(u.ConfigPath) + cfg, err := readConfig(u.ConfigPath) if err != nil { return trace.Errorf("failed to read %s: %w", updateConfigName, err) } - if override.Proxy != "" { - cfg.Spec.Proxy = override.Proxy - } - if override.Group != "" { - cfg.Spec.Group = override.Group - } - if override.URLTemplate != "" { - cfg.Spec.URLTemplate = override.URLTemplate - } - cfg.Spec.Enabled = true - if err := validateUpdatesSpec(&cfg.Spec); err != nil { + if err := validateConfigSpec(&cfg.Spec, override); err != nil { return trace.Wrap(err) } // Lookup target version from the proxy. + addr, err := libutils.ParseAddr(cfg.Spec.Proxy) if err != nil { return trace.Errorf("failed to parse proxy server address: %w", err) } - desiredVersion := override.ForceVersion var flags InstallFlags if desiredVersion == "" { @@ -278,7 +304,9 @@ func (u *Updater) Enable(ctx context.Context, override OverrideConfig) error { u.Log.WarnContext(ctx, "Failed to remove backup version of Teleport before new install.", "error", err) } } - // If the active version and target don't match, kick off upgrade. + + // Install the desired version (or validate existing installation) + template := cfg.Spec.URLTemplate if template == "" { template = cdnURITemplate @@ -287,14 +315,55 @@ func (u *Updater) Enable(ctx context.Context, override OverrideConfig) error { if err != nil { return trace.Errorf("failed to install: %w", err) } - err = u.Installer.Link(ctx, desiredVersion) + revert, err := u.Installer.Link(ctx, desiredVersion) if err != nil { return trace.Errorf("failed to link: %w", err) } + + // If we fail to revert after this point, the next update/enable will + // fix the link to restore the active version. + + // Sync process configuration after linking. + + if err := u.Process.Sync(ctx); err != nil { + if errors.Is(err, context.Canceled) { + return trace.Errorf("sync canceled") + } + // If sync fails, we may have left the host in a bad state, so we revert linking and re-Sync. + u.Log.ErrorContext(ctx, "Reverting symlinks due to invalid configuration.") + if ok := revert(ctx); !ok { + u.Log.ErrorContext(ctx, "Failed to revert Teleport symlinks. Installation likely broken.") + } else if err := u.Process.Sync(ctx); err != nil { + u.Log.ErrorContext(ctx, "Failed to sync configuration after failed restart.", "error", err) + } + u.Log.WarnContext(ctx, "Teleport updater encountered a configuration error and successfully reverted the installation.") + + return trace.Errorf("failed to validate configuration for new version %q of Teleport: %w", desiredVersion, err) + } + + // Restart Teleport if necessary. + if cfg.Status.ActiveVersion != desiredVersion { + u.Log.InfoContext(ctx, "Target version successfully installed.", "version", desiredVersion) + if err := u.Process.Reload(ctx); err != nil && !errors.Is(err, ErrNotNeeded) { + if errors.Is(err, context.Canceled) { + return trace.Errorf("reload canceled") + } + // If reloading Teleport at the new version fails, revert, resync, and reload. + u.Log.ErrorContext(ctx, "Reverting symlinks due to failed restart.") + if ok := revert(ctx); !ok { + u.Log.ErrorContext(ctx, "Failed to revert Teleport symlinks to older version. Installation likely broken.") + } else if err := u.Process.Sync(ctx); err != nil { + u.Log.ErrorContext(ctx, "Invalid configuration found after reverting Teleport to older version. Installation likely broken.", "error", err) + } else if err := u.Process.Reload(ctx); err != nil && !errors.Is(err, ErrNotNeeded) { + u.Log.ErrorContext(ctx, "Failed to revert Teleport to older version. Installation likely broken.", "error", err) + } + u.Log.WarnContext(ctx, "Teleport updater encountered a configuration error and successfully reverted the installation.") + + return trace.Errorf("failed to start new version %q of Teleport: %w", desiredVersion, err) + } cfg.Status.BackupVersion = cfg.Status.ActiveVersion cfg.Status.ActiveVersion = desiredVersion - u.Log.InfoContext(ctx, "Target version successfully installed.", "version", desiredVersion) } else { u.Log.InfoContext(ctx, "Target version successfully validated.", "version", desiredVersion) } @@ -302,6 +371,8 @@ func (u *Updater) Enable(ctx context.Context, override OverrideConfig) error { u.Log.InfoContext(ctx, "Backup version set.", "version", v) } + // Check if manual cleanup might be needed. + versions, err := u.Installer.List(ctx) if err != nil { return trace.Errorf("failed to list installed versions: %w", err) @@ -311,29 +382,19 @@ func (u *Updater) Enable(ctx context.Context, override OverrideConfig) error { } // Always write the configuration file if enable succeeds. - if err := u.writeConfig(u.ConfigPath, cfg); err != nil { + + cfg.Spec.Enabled = true + if err := writeConfig(u.ConfigPath, cfg); err != nil { return trace.Errorf("failed to write %s: %w", updateConfigName, err) } u.Log.InfoContext(ctx, "Configuration updated.") return nil } -func validateUpdatesSpec(spec *UpdateSpec) error { - if spec.URLTemplate != "" && - !strings.HasPrefix(strings.ToLower(spec.URLTemplate), "https://") { - return trace.Errorf("Teleport download URL must use TLS (https://)") - } - - if spec.Proxy == "" { - return trace.Errorf("Teleport proxy URL must be specified with --proxy or present in %s", updateConfigName) - } - return nil -} - // Disable disables agent auto-updates. // This function is idempotent. func (u *Updater) Disable(ctx context.Context) error { - cfg, err := u.readConfig(u.ConfigPath) + cfg, err := readConfig(u.ConfigPath) if err != nil { return trace.Errorf("failed to read %s: %w", updateConfigName, err) } @@ -342,14 +403,14 @@ func (u *Updater) Disable(ctx context.Context) error { return nil } cfg.Spec.Enabled = false - if err := u.writeConfig(u.ConfigPath, cfg); err != nil { + if err := writeConfig(u.ConfigPath, cfg); err != nil { return trace.Errorf("failed to write %s: %w", updateConfigName, err) } return nil } // readConfig reads UpdateConfig from a file. -func (*Updater) readConfig(path string) (*UpdateConfig, error) { +func readConfig(path string) (*UpdateConfig, error) { f, err := os.Open(path) if errors.Is(err, fs.ErrNotExist) { return &UpdateConfig{ @@ -375,7 +436,7 @@ func (*Updater) readConfig(path string) (*UpdateConfig, error) { } // writeConfig writes UpdateConfig to a file atomically, ensuring the file cannot be corrupted. -func (*Updater) writeConfig(filename string, cfg *UpdateConfig) error { +func writeConfig(filename string, cfg *UpdateConfig) error { opts := []renameio.Option{ renameio.WithPermissions(0755), renameio.WithExistingPermissions(), @@ -391,3 +452,23 @@ func (*Updater) writeConfig(filename string, cfg *UpdateConfig) error { } return trace.Wrap(t.CloseAtomicallyReplace()) } + +func validateConfigSpec(spec *UpdateSpec, override OverrideConfig) error { + if override.Proxy != "" { + spec.Proxy = override.Proxy + } + if override.Group != "" { + spec.Group = override.Group + } + if override.URLTemplate != "" { + spec.URLTemplate = override.URLTemplate + } + if spec.URLTemplate != "" && + !strings.HasPrefix(strings.ToLower(spec.URLTemplate), "https://") { + return trace.Errorf("Teleport download URL must use TLS (https://)") + } + if spec.Proxy == "" { + return trace.Errorf("Teleport proxy URL must be specified with --proxy or present in %s", updateConfigName) + } + return nil +} diff --git a/lib/autoupdate/agent/updater_test.go b/lib/autoupdate/agent/updater_test.go index e817851fed1f7..8cefd3a59e3e7 100644 --- a/lib/autoupdate/agent/updater_test.go +++ b/lib/autoupdate/agent/updater_test.go @@ -132,11 +132,16 @@ func TestUpdater_Enable(t *testing.T) { userCfg OverrideConfig installErr error flags InstallFlags + syncErr error + reloadErr error removedVersion string installedVersion string installedTemplate string requestGroup string + syncCalls int + reloadCalls int + revertCalls int errMatch string }{ { @@ -152,9 +157,12 @@ func TestUpdater_Enable(t *testing.T) { ActiveVersion: "old-version", }, }, + installedVersion: "16.3.0", installedTemplate: "https://example.com", requestGroup: "group", + syncCalls: 1, + reloadCalls: 1, }, { name: "config from user", @@ -174,8 +182,11 @@ func TestUpdater_Enable(t *testing.T) { URLTemplate: "https://example.com/new", ForceVersion: "new-version", }, + installedVersion: "new-version", installedTemplate: "https://example.com/new", + syncCalls: 1, + reloadCalls: 1, }, { name: "already enabled", @@ -189,8 +200,11 @@ func TestUpdater_Enable(t *testing.T) { ActiveVersion: "old-version", }, }, + installedVersion: "16.3.0", installedTemplate: cdnURITemplate, + syncCalls: 1, + reloadCalls: 1, }, { name: "insecure URL", @@ -201,6 +215,7 @@ func TestUpdater_Enable(t *testing.T) { URLTemplate: "http://example.com", }, }, + errMatch: "URL must use TLS", }, { @@ -213,7 +228,8 @@ func TestUpdater_Enable(t *testing.T) { }, }, installErr: errors.New("install error"), - errMatch: "install error", + + errMatch: "install error", }, { name: "version already installed", @@ -224,8 +240,11 @@ func TestUpdater_Enable(t *testing.T) { ActiveVersion: "16.3.0", }, }, + installedVersion: "16.3.0", installedTemplate: cdnURITemplate, + syncCalls: 1, + reloadCalls: 0, }, { name: "backup version removed on install", @@ -237,9 +256,12 @@ func TestUpdater_Enable(t *testing.T) { BackupVersion: "backup-version", }, }, + installedVersion: "16.3.0", installedTemplate: cdnURITemplate, removedVersion: "backup-version", + syncCalls: 1, + reloadCalls: 1, }, { name: "backup version kept for validation", @@ -251,26 +273,56 @@ func TestUpdater_Enable(t *testing.T) { BackupVersion: "backup-version", }, }, + installedVersion: "16.3.0", installedTemplate: cdnURITemplate, removedVersion: "", + syncCalls: 1, + reloadCalls: 0, }, { - name: "config does not exist", + name: "config does not exist", + installedVersion: "16.3.0", installedTemplate: cdnURITemplate, + syncCalls: 1, + reloadCalls: 1, }, { name: "FIPS and Enterprise flags", flags: FlagEnterprise | FlagFIPS, installedVersion: "16.3.0", installedTemplate: cdnURITemplate, + syncCalls: 1, + reloadCalls: 1, }, { name: "invalid metadata", cfg: &UpdateConfig{}, errMatch: "invalid", }, + { + name: "sync fails", + syncErr: errors.New("sync error"), + + installedVersion: "16.3.0", + installedTemplate: cdnURITemplate, + syncCalls: 2, + reloadCalls: 0, + revertCalls: 1, + errMatch: "sync error", + }, + { + name: "reload fails", + reloadErr: errors.New("reload error"), + + installedVersion: "16.3.0", + installedTemplate: cdnURITemplate, + syncCalls: 2, + reloadCalls: 2, + revertCalls: 1, + errMatch: "reload error", + }, } for _, tt := range tests { @@ -320,6 +372,7 @@ func TestUpdater_Enable(t *testing.T) { linkedVersion string removedVersion string installedFlags InstallFlags + revertCalls int ) updater.Installer = &testInstaller{ FuncInstall: func(_ context.Context, version, template string, flags InstallFlags) error { @@ -328,9 +381,12 @@ func TestUpdater_Enable(t *testing.T) { installedFlags = flags return tt.installErr }, - FuncLink: func(_ context.Context, version string) error { + FuncLink: func(_ context.Context, version string) (revert func(context.Context) bool, err error) { linkedVersion = version - return nil + return func(_ context.Context) bool { + revertCalls++ + return true + }, nil }, FuncList: func(_ context.Context) (versions []string, err error) { return []string{"old"}, nil @@ -340,6 +396,20 @@ func TestUpdater_Enable(t *testing.T) { return nil }, } + var ( + syncCalls int + reloadCalls int + ) + updater.Process = &testProcess{ + FuncSync: func(_ context.Context) error { + syncCalls++ + return tt.syncErr + }, + FuncReload: func(_ context.Context) error { + reloadCalls++ + return tt.reloadErr + }, + } ctx := context.Background() err = updater.Enable(ctx, tt.userCfg) @@ -355,6 +425,9 @@ func TestUpdater_Enable(t *testing.T) { require.Equal(t, tt.removedVersion, removedVersion) require.Equal(t, tt.flags, installedFlags) require.Equal(t, tt.requestGroup, requestedGroup) + require.Equal(t, tt.syncCalls, syncCalls) + require.Equal(t, tt.reloadCalls, reloadCalls) + require.Equal(t, tt.revertCalls, revertCalls) data, err := os.ReadFile(cfgPath) require.NoError(t, err) @@ -377,7 +450,7 @@ func blankTestAddr(s []byte) []byte { type testInstaller struct { FuncInstall func(ctx context.Context, version, template string, flags InstallFlags) error FuncRemove func(ctx context.Context, version string) error - FuncLink func(ctx context.Context, version string) error + FuncLink func(ctx context.Context, version string) (revert func(context.Context) bool, err error) FuncList func(ctx context.Context) (versions []string, err error) } @@ -389,10 +462,23 @@ func (ti *testInstaller) Remove(ctx context.Context, version string) error { return ti.FuncRemove(ctx, version) } -func (ti *testInstaller) Link(ctx context.Context, version string) error { +func (ti *testInstaller) Link(ctx context.Context, version string) (revert func(context.Context) bool, err error) { return ti.FuncLink(ctx, version) } func (ti *testInstaller) List(ctx context.Context) (versions []string, err error) { return ti.FuncList(ctx) } + +type testProcess struct { + FuncReload func(ctx context.Context) error + FuncSync func(ctx context.Context) error +} + +func (tp *testProcess) Reload(ctx context.Context) error { + return tp.FuncReload(ctx) +} + +func (tp *testProcess) Sync(ctx context.Context) error { + return tp.FuncSync(ctx) +} diff --git a/lib/utils/unpack.go b/lib/utils/unpack.go index 78b111daf8992..14b213f08a173 100644 --- a/lib/utils/unpack.go +++ b/lib/utils/unpack.go @@ -50,7 +50,8 @@ func Extract(r io.Reader, dir string, paths ...ExtractPath) error { } else if err != nil { return trace.Wrap(err) } - if ok := filterHeader(header, paths); !ok { + dirMode, ok := filterHeader(header, paths) + if !ok { continue } err = sanitizeTarPath(header, dir) @@ -58,7 +59,7 @@ func Extract(r io.Reader, dir string, paths ...ExtractPath) error { return trace.Wrap(err) } - if err := extractFile(tarball, header, dir); err != nil { + if err := extractFile(tarball, header, dir, dirMode); err != nil { return trace.Wrap(err) } } @@ -74,11 +75,15 @@ type ExtractPath struct { Src, Dst string // Skip extracting the Src path and ignore Dst. Skip bool + // DirMode is the file mode for implicit parent directories in Dst. + DirMode os.FileMode } // filterHeader modifies the tar header by filtering it through the ExtractPaths. // filterHeader returns false if the tar header should be skipped. -func filterHeader(hdr *tar.Header, paths []ExtractPath) (include bool) { +// If no paths are provided, filterHeader assumes the header should be included, and sets +// the mode for implicit parent directories to teleport.DirMaskSharedGroup. +func filterHeader(hdr *tar.Header, paths []ExtractPath) (dirMode os.FileMode, include bool) { name := path.Clean(hdr.Name) for _, p := range paths { src := path.Clean(p.Src) @@ -98,14 +103,14 @@ func filterHeader(hdr *tar.Header, paths []ExtractPath) (include bool) { dst += "/" // tar directory headers end in / } hdr.Name = dst - return !p.Skip + return p.DirMode, !p.Skip default: // If name is a file, then // if src is an exact match to the file name, assume src is a file and write directly to dst, // otherwise, assume src is a directory prefix, and replace that prefix with dst. if src == name { hdr.Name = path.Clean(p.Dst) - return !p.Skip + return p.DirMode, !p.Skip } if src != "/" { src += "/" // ensure HasPrefix does not match partial names @@ -114,26 +119,26 @@ func filterHeader(hdr *tar.Header, paths []ExtractPath) (include bool) { continue } hdr.Name = path.Join(p.Dst, strings.TrimPrefix(name, src)) - return !p.Skip + return p.DirMode, !p.Skip } } - return len(paths) == 0 + return teleport.DirMaskSharedGroup, len(paths) == 0 } // extractFile extracts a single file or directory from tarball into dir. // Uses header to determine the type of item to create // Based on https://github.com/mholt/archiver -func extractFile(tarball *tar.Reader, header *tar.Header, dir string) error { +func extractFile(tarball *tar.Reader, header *tar.Header, dir string, dirMode os.FileMode) error { switch header.Typeflag { case tar.TypeDir: - return withDir(filepath.Join(dir, header.Name), nil) + return withDir(filepath.Join(dir, header.Name), dirMode, nil) case tar.TypeBlock, tar.TypeChar, tar.TypeReg, tar.TypeFifo: - return writeFile(filepath.Join(dir, header.Name), tarball, header.FileInfo().Mode()) + return writeFile(filepath.Join(dir, header.Name), tarball, header.FileInfo().Mode(), dirMode) case tar.TypeLink: - return writeHardLink(filepath.Join(dir, header.Name), filepath.Join(dir, header.Linkname)) + return writeHardLink(filepath.Join(dir, header.Name), filepath.Join(dir, header.Linkname), dirMode) case tar.TypeSymlink: - return writeSymbolicLink(filepath.Join(dir, header.Name), header.Linkname) + return writeSymbolicLink(filepath.Join(dir, header.Name), header.Linkname, dirMode) default: log.Warnf("Unsupported type flag %v for %v.", header.Typeflag, header.Name) } @@ -168,8 +173,8 @@ func sanitizeTarPath(header *tar.Header, dir string) error { return nil } -func writeFile(path string, r io.Reader, mode os.FileMode) error { - err := withDir(path, func() error { +func writeFile(path string, r io.Reader, mode, dirMode os.FileMode) error { + err := withDir(path, dirMode, func() error { // Create file only if it does not exist to prevent overwriting existing // files (like session recordings). out, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_WRONLY, mode) @@ -182,24 +187,24 @@ func writeFile(path string, r io.Reader, mode os.FileMode) error { return trace.Wrap(err) } -func writeSymbolicLink(path string, target string) error { - err := withDir(path, func() error { +func writeSymbolicLink(path, target string, dirMode os.FileMode) error { + err := withDir(path, dirMode, func() error { err := os.Symlink(target, path) return trace.ConvertSystemError(err) }) return trace.Wrap(err) } -func writeHardLink(path string, target string) error { - err := withDir(path, func() error { +func writeHardLink(path, target string, dirMode os.FileMode) error { + err := withDir(path, dirMode, func() error { err := os.Link(target, path) return trace.ConvertSystemError(err) }) return trace.Wrap(err) } -func withDir(path string, fn func() error) error { - err := os.MkdirAll(filepath.Dir(path), teleport.DirMaskSharedGroup) +func withDir(path string, mode os.FileMode, fn func() error) error { + err := os.MkdirAll(filepath.Dir(path), mode) if err != nil { return trace.ConvertSystemError(err) } diff --git a/tool/teleport-update/main.go b/tool/teleport-update/main.go index 300da6736471a..2adce83a1877c 100644 --- a/tool/teleport-update/main.go +++ b/tool/teleport-update/main.go @@ -61,6 +61,8 @@ const ( versionsDirName = "versions" // lockFileName specifies the name of the file inside versionsDirName containing the flock lock preventing concurrent updater execution. lockFileName = ".lock" + // defaultLinkDir is the default location where Teleport binaries and services are linked. + defaultLinkDir = "/usr/local" ) var plog = logutils.NewPackageLogger(teleport.ComponentKey, teleport.ComponentUpdater) @@ -98,7 +100,7 @@ func Run(args []string) error { app.Flag("log-format", "Controls the format of output logs. Can be `json` or `text`. Defaults to `text`."). Default(libutils.LogFormatText).EnumVar(&ccfg.LogFormat, libutils.LogFormatJSON, libutils.LogFormatText) app.Flag("link-dir", "Directory to create system symlinks to binaries and services."). - Default(filepath.Join("usr", "local")).Hidden().StringVar(&ccfg.LinkDir) + Default(defaultLinkDir).Hidden().StringVar(&ccfg.LinkDir) app.HelpFlag.Short('h')