Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[teleport-update] Add support for reloading the agent & reverting symlinks on failed reload #47929

Merged
merged 24 commits into from
Nov 4, 2024
Merged
136 changes: 110 additions & 26 deletions lib/autoupdate/agent/installer.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"os"
"path/filepath"
"runtime"
"syscall"
"text/template"
"time"

Expand All @@ -50,13 +51,13 @@ var (
// See utils.Extract for more details on how this list is parsed.
// Paths must use tarball-style / separators (not filepath).
tgzExtractPaths = []utils.ExtractPath{
{Src: "teleport/examples/systemd/teleport.service", Dst: "etc/systemd/teleport.service"},
{Src: "teleport/examples", Skip: true},
{Src: "teleport/install", Skip: true},
{Src: "teleport/README.md", Dst: "share/README.md"},
{Src: "teleport/CHANGELOG.md", Dst: "share/CHANGELOG.md"},
{Src: "teleport/VERSION", Dst: "share/VERSION"},
{Src: "teleport", Dst: "bin"},
{Src: "teleport/examples/systemd/teleport.service", Dst: "etc/systemd/teleport.service", DirMode: 0755},
{Src: "teleport/examples", Skip: true, DirMode: 0755},
{Src: "teleport/install", Skip: true, DirMode: 0755},
{Src: "teleport/README.md", Dst: "share/README.md", DirMode: 0755},
{Src: "teleport/CHANGELOG.md", Dst: "share/CHANGELOG.md", DirMode: 0755},
{Src: "teleport/VERSION", Dst: "share/VERSION", DirMode: 0755},
{Src: "teleport", Dst: "bin", DirMode: 0755},
}

// servicePath contains the path to the Teleport SystemD service within the version directory.
Expand All @@ -82,11 +83,9 @@ type LocalInstaller struct {
ReservedFreeInstallDisk uint64
}

// ErrLinked is returned when a linked version cannot be removed.
var ErrLinked = errors.New("linked version cannot be removed")

// Remove a Teleport version directory from InstallDir.
// This function is idempotent.
// See Installer interface for additional specs.
func (li *LocalInstaller) Remove(ctx context.Context, version string) error {
// os.RemoveAll is dangerous because it can remove an entire directory tree.
// We must validate the version to ensure that we remove only a single path
Expand All @@ -102,7 +101,7 @@ func (li *LocalInstaller) Remove(ctx context.Context, version string) error {
return trace.Errorf("failed to determine if linked: %w", err)
}
if linked {
return trace.Wrap(ErrLinked)
return trace.Errorf("refusing to remove: %w", ErrLinked)
}

// invalidate checksum first, to protect against partially-removed
Expand All @@ -119,7 +118,8 @@ func (li *LocalInstaller) Remove(ctx context.Context, version string) error {

// Install a Teleport version directory in InstallDir.
// This function is idempotent.
func (li *LocalInstaller) Install(ctx context.Context, version, template string, flags InstallFlags) error {
// See Installer interface for additional specs.
func (li *LocalInstaller) Install(ctx context.Context, version, template string, flags InstallFlags) (err error) {
versionDir, err := li.versionDir(version)
if err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -175,11 +175,17 @@ func (li *LocalInstaller) Install(ctx context.Context, version, template string,
if err != nil {
return trace.Errorf("failed to download teleport: %w", err)
}

// Seek to the start of the tgz file after writing
if _, err := f.Seek(0, io.SeekStart); err != nil {
return trace.Errorf("failed seek to start of download: %w", err)
}

// If interrupted, close the file immediately to stop extracting.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
context.AfterFunc(ctx, func() {
_ = f.Close() // safe to close file multiple times
})
// Check integrity before decompression
if !bytes.Equal(newSum, pathSum) {
return trace.Errorf("mismatched checksum, download possibly corrupt")
Expand All @@ -193,6 +199,17 @@ func (li *LocalInstaller) Install(ctx context.Context, version, template string,
if _, err := f.Seek(0, io.SeekStart); err != nil {
return trace.Errorf("failed seek to start: %w", err)
}

// If there's an error after we start extracting, delete the version dir.
defer func() {
if err != nil {
if err := os.RemoveAll(versionDir); err != nil {
li.Log.WarnContext(ctx, "Failed to cleanup broken version extraction.", "error", err, "dir", versionDir)
}
}
}()

// Extract tgz into version directory.
if err := li.extract(ctx, versionDir, f, n); err != nil {
return trace.Errorf("failed to extract teleport: %w", err)
}
Expand Down Expand Up @@ -374,51 +391,118 @@ func (li *LocalInstaller) List(ctx context.Context) (versions []string, err erro
return versions, nil
}

// Link the specified version into the system LinkBinDir.
func (li *LocalInstaller) Link(ctx context.Context, version string) error {
// Link the specified version into the system LinkBinDir and LinkServiceDir.
// The revert function restores the previous linking.
// See Installer interface for additional specs.
func (li *LocalInstaller) Link(ctx context.Context, version string) (revert func(context.Context) bool, err error) {
// setup revert function
type symlink struct {
old, new string
}
var revertLinks []symlink
revert = func(ctx context.Context) bool {
// This function is safe to call repeatedly.
// Returns true only when all symlinks are successfully reverted.
var keep []symlink
for _, l := range revertLinks {
err := renameio.Symlink(l.old, l.new)
if err != nil {
keep = append(keep, l)
li.Log.ErrorContext(ctx, "Failed to revert symlink", "old", l.old, "new", l.new, "err", err)
}
}
revertLinks = keep
return len(revertLinks) == 0
}
// revert immediately on error, so caller can ignore revert arg
defer func() {
if err != nil {
revert(ctx)
}
}()

versionDir, err := li.versionDir(version)
if err != nil {
return trace.Wrap(err)
return revert, trace.Wrap(err)
}

// ensure target directories exist before trying to create links
err = os.MkdirAll(li.LinkBinDir, 0755)
if err != nil {
return trace.Wrap(err)
return revert, trace.Wrap(err)
}
err = os.MkdirAll(li.LinkServiceDir, 0755)
if err != nil {
return trace.Wrap(err)
return revert, trace.Wrap(err)
}

// create binary links

binDir := filepath.Join(versionDir, "bin")
entries, err := os.ReadDir(binDir)
if err != nil {
return trace.Errorf("failed to find Teleport binary directory: %w", err)
return revert, trace.Errorf("failed to find Teleport binary directory: %w", err)
}
var linked int
for _, entry := range entries {
if entry.IsDir() {
continue
}
err := renameio.Symlink(filepath.Join(binDir, entry.Name()), filepath.Join(li.LinkBinDir, entry.Name()))
oldname := filepath.Join(binDir, entry.Name())
newname := filepath.Join(li.LinkBinDir, entry.Name())
orig, err := tryLink(oldname, newname)
if err != nil {
return trace.Wrap(err)
return revert, trace.Errorf("failed to create symlink for %s: %w", filepath.Base(oldname), err)
}
if orig != "" {
revertLinks = append(revertLinks, symlink{
old: orig,
new: newname,
})
}
linked++
}
if linked == 0 {
return trace.Errorf("no binaries available to link")
return revert, trace.Errorf("no binaries available to link")
}

// create systemd service link
service := filepath.Join(versionDir, servicePath)
err = renameio.Symlink(service, filepath.Join(li.LinkServiceDir, filepath.Base(servicePath)))

oldname := filepath.Join(versionDir, servicePath)
newname := filepath.Join(li.LinkServiceDir, filepath.Base(servicePath))
orig, err := tryLink(oldname, newname)
if err != nil {
return trace.Wrap(err)
return revert, trace.Errorf("failed to create symlink for %s: %w", filepath.Base(oldname), err)
}
return nil
if orig != "" {
revertLinks = append(revertLinks, symlink{
old: orig,
new: newname,
})
}
return revert, nil
}

// tryLink attempts to create a symlink, atomically replacing an existing link if already present.
// If a non-symlink file or directory exists in newname already, tryLink errors.
func tryLink(oldname, newname string) (orig string, err error) {
orig, err = os.Readlink(newname)
if errors.Is(err, os.ErrInvalid) ||
errors.Is(err, syscall.EINVAL) { // workaround missing ErrInvalid wrapper
// important: do not attempt to replace a non-linked install of Teleport
return orig, trace.Errorf("refusing to replace file at %s", newname)
}
if err != nil && !errors.Is(err, os.ErrNotExist) {
return orig, trace.Wrap(err)
}
if orig == oldname {
return "", nil
}
err = renameio.Symlink(oldname, newname)
if err != nil {
return orig, trace.Wrap(err)
}
return orig, nil
}

// versionDir returns the storage directory for a Teleport version.
Expand Down
Loading
Loading