diff --git a/lib/autoupdate/agent/installer.go b/lib/autoupdate/agent/installer.go
index 4da41e8e55509..96e72c0a5cfa3 100644
--- a/lib/autoupdate/agent/installer.go
+++ b/lib/autoupdate/agent/installer.go
@@ -31,6 +31,7 @@ import (
"os"
"path/filepath"
"runtime"
+ "syscall"
"text/template"
"time"
@@ -50,13 +51,13 @@ var (
// See utils.Extract for more details on how this list is parsed.
// Paths must use tarball-style / separators (not filepath).
tgzExtractPaths = []utils.ExtractPath{
- {Src: "teleport/examples/systemd/teleport.service", Dst: "etc/systemd/teleport.service"},
- {Src: "teleport/examples", Skip: true},
- {Src: "teleport/install", Skip: true},
- {Src: "teleport/README.md", Dst: "share/README.md"},
- {Src: "teleport/CHANGELOG.md", Dst: "share/CHANGELOG.md"},
- {Src: "teleport/VERSION", Dst: "share/VERSION"},
- {Src: "teleport", Dst: "bin"},
+ {Src: "teleport/examples/systemd/teleport.service", Dst: "etc/systemd/teleport.service", DirMode: 0755},
+ {Src: "teleport/examples", Skip: true, DirMode: 0755},
+ {Src: "teleport/install", Skip: true, DirMode: 0755},
+ {Src: "teleport/README.md", Dst: "share/README.md", DirMode: 0755},
+ {Src: "teleport/CHANGELOG.md", Dst: "share/CHANGELOG.md", DirMode: 0755},
+ {Src: "teleport/VERSION", Dst: "share/VERSION", DirMode: 0755},
+ {Src: "teleport", Dst: "bin", DirMode: 0755},
}
// servicePath contains the path to the Teleport SystemD service within the version directory.
@@ -82,11 +83,9 @@ type LocalInstaller struct {
ReservedFreeInstallDisk uint64
}
-// ErrLinked is returned when a linked version cannot be removed.
-var ErrLinked = errors.New("linked version cannot be removed")
-
// Remove a Teleport version directory from InstallDir.
// This function is idempotent.
+// See Installer interface for additional specs.
func (li *LocalInstaller) Remove(ctx context.Context, version string) error {
// os.RemoveAll is dangerous because it can remove an entire directory tree.
// We must validate the version to ensure that we remove only a single path
@@ -102,7 +101,7 @@ func (li *LocalInstaller) Remove(ctx context.Context, version string) error {
return trace.Errorf("failed to determine if linked: %w", err)
}
if linked {
- return trace.Wrap(ErrLinked)
+ return trace.Errorf("refusing to remove: %w", ErrLinked)
}
// invalidate checksum first, to protect against partially-removed
@@ -119,7 +118,8 @@ func (li *LocalInstaller) Remove(ctx context.Context, version string) error {
// Install a Teleport version directory in InstallDir.
// This function is idempotent.
-func (li *LocalInstaller) Install(ctx context.Context, version, template string, flags InstallFlags) error {
+// See Installer interface for additional specs.
+func (li *LocalInstaller) Install(ctx context.Context, version, template string, flags InstallFlags) (err error) {
versionDir, err := li.versionDir(version)
if err != nil {
return trace.Wrap(err)
@@ -175,11 +175,17 @@ func (li *LocalInstaller) Install(ctx context.Context, version, template string,
if err != nil {
return trace.Errorf("failed to download teleport: %w", err)
}
-
// Seek to the start of the tgz file after writing
if _, err := f.Seek(0, io.SeekStart); err != nil {
return trace.Errorf("failed seek to start of download: %w", err)
}
+
+ // If interrupted, close the file immediately to stop extracting.
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+ context.AfterFunc(ctx, func() {
+ _ = f.Close() // safe to close file multiple times
+ })
// Check integrity before decompression
if !bytes.Equal(newSum, pathSum) {
return trace.Errorf("mismatched checksum, download possibly corrupt")
@@ -193,6 +199,17 @@ func (li *LocalInstaller) Install(ctx context.Context, version, template string,
if _, err := f.Seek(0, io.SeekStart); err != nil {
return trace.Errorf("failed seek to start: %w", err)
}
+
+ // If there's an error after we start extracting, delete the version dir.
+ defer func() {
+ if err != nil {
+ if err := os.RemoveAll(versionDir); err != nil {
+ li.Log.WarnContext(ctx, "Failed to cleanup broken version extraction.", "error", err, "dir", versionDir)
+ }
+ }
+ }()
+
+ // Extract tgz into version directory.
if err := li.extract(ctx, versionDir, f, n); err != nil {
return trace.Errorf("failed to extract teleport: %w", err)
}
@@ -374,51 +391,118 @@ func (li *LocalInstaller) List(ctx context.Context) (versions []string, err erro
return versions, nil
}
-// Link the specified version into the system LinkBinDir.
-func (li *LocalInstaller) Link(ctx context.Context, version string) error {
+// Link the specified version into the system LinkBinDir and LinkServiceDir.
+// The revert function restores the previous linking.
+// See Installer interface for additional specs.
+func (li *LocalInstaller) Link(ctx context.Context, version string) (revert func(context.Context) bool, err error) {
+ // setup revert function
+ type symlink struct {
+ old, new string
+ }
+ var revertLinks []symlink
+ revert = func(ctx context.Context) bool {
+ // This function is safe to call repeatedly.
+ // Returns true only when all symlinks are successfully reverted.
+ var keep []symlink
+ for _, l := range revertLinks {
+ err := renameio.Symlink(l.old, l.new)
+ if err != nil {
+ keep = append(keep, l)
+ li.Log.ErrorContext(ctx, "Failed to revert symlink", "old", l.old, "new", l.new, "err", err)
+ }
+ }
+ revertLinks = keep
+ return len(revertLinks) == 0
+ }
+ // revert immediately on error, so caller can ignore revert arg
+ defer func() {
+ if err != nil {
+ revert(ctx)
+ }
+ }()
+
versionDir, err := li.versionDir(version)
if err != nil {
- return trace.Wrap(err)
+ return revert, trace.Wrap(err)
}
// ensure target directories exist before trying to create links
err = os.MkdirAll(li.LinkBinDir, 0755)
if err != nil {
- return trace.Wrap(err)
+ return revert, trace.Wrap(err)
}
err = os.MkdirAll(li.LinkServiceDir, 0755)
if err != nil {
- return trace.Wrap(err)
+ return revert, trace.Wrap(err)
}
// create binary links
+
binDir := filepath.Join(versionDir, "bin")
entries, err := os.ReadDir(binDir)
if err != nil {
- return trace.Errorf("failed to find Teleport binary directory: %w", err)
+ return revert, trace.Errorf("failed to find Teleport binary directory: %w", err)
}
var linked int
for _, entry := range entries {
if entry.IsDir() {
continue
}
- err := renameio.Symlink(filepath.Join(binDir, entry.Name()), filepath.Join(li.LinkBinDir, entry.Name()))
+ oldname := filepath.Join(binDir, entry.Name())
+ newname := filepath.Join(li.LinkBinDir, entry.Name())
+ orig, err := tryLink(oldname, newname)
if err != nil {
- return trace.Wrap(err)
+ return revert, trace.Errorf("failed to create symlink for %s: %w", filepath.Base(oldname), err)
+ }
+ if orig != "" {
+ revertLinks = append(revertLinks, symlink{
+ old: orig,
+ new: newname,
+ })
}
linked++
}
if linked == 0 {
- return trace.Errorf("no binaries available to link")
+ return revert, trace.Errorf("no binaries available to link")
}
// create systemd service link
- service := filepath.Join(versionDir, servicePath)
- err = renameio.Symlink(service, filepath.Join(li.LinkServiceDir, filepath.Base(servicePath)))
+
+ oldname := filepath.Join(versionDir, servicePath)
+ newname := filepath.Join(li.LinkServiceDir, filepath.Base(servicePath))
+ orig, err := tryLink(oldname, newname)
if err != nil {
- return trace.Wrap(err)
+ return revert, trace.Errorf("failed to create symlink for %s: %w", filepath.Base(oldname), err)
}
- return nil
+ if orig != "" {
+ revertLinks = append(revertLinks, symlink{
+ old: orig,
+ new: newname,
+ })
+ }
+ return revert, nil
+}
+
+// tryLink attempts to create a symlink, atomically replacing an existing link if already present.
+// If a non-symlink file or directory exists in newname already, tryLink errors.
+func tryLink(oldname, newname string) (orig string, err error) {
+ orig, err = os.Readlink(newname)
+ if errors.Is(err, os.ErrInvalid) ||
+ errors.Is(err, syscall.EINVAL) { // workaround missing ErrInvalid wrapper
+ // important: do not attempt to replace a non-linked install of Teleport
+ return orig, trace.Errorf("refusing to replace file at %s", newname)
+ }
+ if err != nil && !errors.Is(err, os.ErrNotExist) {
+ return orig, trace.Wrap(err)
+ }
+ if orig == oldname {
+ return "", nil
+ }
+ err = renameio.Symlink(oldname, newname)
+ if err != nil {
+ return orig, trace.Wrap(err)
+ }
+ return orig, nil
}
// versionDir returns the storage directory for a Teleport version.
diff --git a/lib/autoupdate/agent/installer_test.go b/lib/autoupdate/agent/installer_test.go
index 2602704208855..d4f58f782dc62 100644
--- a/lib/autoupdate/agent/installer_test.go
+++ b/lib/autoupdate/agent/installer_test.go
@@ -196,16 +196,18 @@ func TestLocalInstaller_Link(t *testing.T) {
const version = "new-version"
tests := []struct {
- name string
- dirs []string
- files []string
+ name string
+ installDirs []string
+ installFiles []string
+ existingLinks []string
+ existingFiles []string
- links []string
- errMatch string
+ resultLinks []string
+ errMatch string
}{
{
- name: "present",
- dirs: []string{
+ name: "present with new links",
+ installDirs: []string{
"bin",
"bin/somedir",
"etc",
@@ -213,7 +215,7 @@ func TestLocalInstaller_Link(t *testing.T) {
"etc/systemd/somedir",
"somedir",
},
- files: []string{
+ installFiles: []string{
"bin/teleport",
"bin/tsh",
"bin/tbot",
@@ -221,7 +223,7 @@ func TestLocalInstaller_Link(t *testing.T) {
"README",
},
- links: []string{
+ resultLinks: []string{
"bin/teleport",
"bin/tsh",
"bin/tbot",
@@ -229,15 +231,102 @@ func TestLocalInstaller_Link(t *testing.T) {
},
},
{
- name: "no links",
- files: []string{"README"},
- dirs: []string{"bin"},
+ name: "present with existing links",
+ installDirs: []string{
+ "bin",
+ "bin/somedir",
+ "etc",
+ "etc/systemd",
+ "etc/systemd/somedir",
+ "somedir",
+ },
+ installFiles: []string{
+ "bin/teleport",
+ "bin/tsh",
+ "bin/tbot",
+ servicePath,
+ "README",
+ },
+ existingLinks: []string{
+ "bin/teleport",
+ "bin/tsh",
+ "bin/tbot",
+ "lib/systemd/system/teleport.service",
+ },
+
+ resultLinks: []string{
+ "bin/teleport",
+ "bin/tsh",
+ "bin/tbot",
+ "lib/systemd/system/teleport.service",
+ },
+ },
+ {
+ name: "conflicting systemd files",
+ installDirs: []string{
+ "bin",
+ "bin/somedir",
+ "etc",
+ "etc/systemd",
+ "etc/systemd/somedir",
+ "somedir",
+ },
+ installFiles: []string{
+ "bin/teleport",
+ "bin/tsh",
+ "bin/tbot",
+ servicePath,
+ "README",
+ },
+ existingLinks: []string{
+ "bin/teleport",
+ "bin/tsh",
+ "bin/tbot",
+ },
+ existingFiles: []string{
+ "lib/systemd/system/teleport.service",
+ },
+
+ errMatch: "refusing",
+ },
+ {
+ name: "conflicting bin files",
+ installDirs: []string{
+ "bin",
+ "bin/somedir",
+ "etc",
+ "etc/systemd",
+ "etc/systemd/somedir",
+ "somedir",
+ },
+ installFiles: []string{
+ "bin/teleport",
+ "bin/tsh",
+ "bin/tbot",
+ servicePath,
+ "README",
+ },
+ existingLinks: []string{
+ "bin/teleport",
+ "bin/tbot",
+ "lib/systemd/system/teleport.service",
+ },
+ existingFiles: []string{
+ "bin/tsh",
+ },
+
+ errMatch: "refusing",
+ },
+ {
+ name: "no links",
+ installFiles: []string{"README"},
+ installDirs: []string{"bin"},
errMatch: "no binaries",
},
{
- name: "no bin directory",
- files: []string{"README"},
+ name: "no bin directory",
+ installFiles: []string{"README"},
errMatch: "binary directory",
},
@@ -251,16 +340,30 @@ func TestLocalInstaller_Link(t *testing.T) {
err := os.MkdirAll(versionDir, 0o755)
require.NoError(t, err)
- for _, d := range tt.dirs {
+ // setup files in version directory
+ for _, d := range tt.installDirs {
err := os.Mkdir(filepath.Join(versionDir, d), os.ModePerm)
require.NoError(t, err)
}
- for _, n := range tt.files {
+ for _, n := range tt.installFiles {
err := os.WriteFile(filepath.Join(versionDir, n), []byte(filepath.Base(n)), os.ModePerm)
require.NoError(t, err)
}
+ // setup files in system links directory
linkDir := t.TempDir()
+ for _, n := range tt.existingLinks {
+ err := os.MkdirAll(filepath.Dir(filepath.Join(linkDir, n)), os.ModePerm)
+ require.NoError(t, err)
+ err = os.Symlink(filepath.Base(n)+".old", filepath.Join(linkDir, n))
+ require.NoError(t, err)
+ }
+ for _, n := range tt.existingFiles {
+ err := os.MkdirAll(filepath.Dir(filepath.Join(linkDir, n)), os.ModePerm)
+ require.NoError(t, err)
+ err = os.WriteFile(filepath.Join(linkDir, n), []byte(filepath.Base(n)), os.ModePerm)
+ require.NoError(t, err)
+ }
installer := &LocalInstaller{
InstallDir: versionsDir,
@@ -269,19 +372,50 @@ func TestLocalInstaller_Link(t *testing.T) {
Log: slog.Default(),
}
ctx := context.Background()
- err = installer.Link(ctx, version)
+ revert, err := installer.Link(ctx, version)
if tt.errMatch != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errMatch)
+
+ // verify automatic revert
+ for _, link := range tt.existingLinks {
+ v, err := os.Readlink(filepath.Join(linkDir, link))
+ require.NoError(t, err)
+ require.Equal(t, filepath.Base(link)+".old", v)
+ }
+ for _, n := range tt.existingFiles {
+ v, err := os.ReadFile(filepath.Join(linkDir, n))
+ require.NoError(t, err)
+ require.Equal(t, filepath.Base(n), string(v))
+ }
+
+ // ensure revert still succeeds
+ ok := revert(ctx)
+ require.True(t, ok)
return
}
require.NoError(t, err)
- for _, link := range tt.links {
+ // verify links
+ for _, link := range tt.resultLinks {
v, err := os.ReadFile(filepath.Join(linkDir, link))
require.NoError(t, err)
require.Equal(t, filepath.Base(link), string(v))
}
+
+ // verify manual revert
+ ok := revert(ctx)
+ require.True(t, ok)
+ for _, link := range tt.existingLinks {
+ v, err := os.Readlink(filepath.Join(linkDir, link))
+ require.NoError(t, err)
+ require.Equal(t, filepath.Base(link)+".old", v)
+ }
+ for _, n := range tt.existingFiles {
+ v, err := os.ReadFile(filepath.Join(linkDir, n))
+ require.NoError(t, err)
+ require.Equal(t, filepath.Base(n), string(v))
+ }
})
}
}
@@ -397,7 +531,7 @@ func TestLocalInstaller_Remove(t *testing.T) {
ctx := context.Background()
if tt.linkedVersion != "" {
- err = installer.Link(ctx, tt.linkedVersion)
+ _, err = installer.Link(ctx, tt.linkedVersion)
require.NoError(t, err)
}
err = installer.Remove(ctx, tt.removeVersion)
diff --git a/lib/autoupdate/agent/process.go b/lib/autoupdate/agent/process.go
new file mode 100644
index 0000000000000..eba70aa56a690
--- /dev/null
+++ b/lib/autoupdate/agent/process.go
@@ -0,0 +1,179 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package agent
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "log/slog"
+ "os"
+ "os/exec"
+
+ "github.com/gravitational/trace"
+)
+
+// SystemdService manages a Teleport systemd service.
+type SystemdService struct {
+ // ServiceName specifies the systemd service name.
+ ServiceName string
+ // Log contains a logger.
+ Log *slog.Logger
+}
+
+// Reload a systemd service.
+// Attempts a graceful reload before a hard restart.
+// See Process interface for more details.
+func (s SystemdService) Reload(ctx context.Context) error {
+ if err := s.checkSystem(ctx); err != nil {
+ return trace.Wrap(err)
+ }
+ // Command error codes < 0 indicate that we are unable to run the command.
+ // Errors from s.systemctl are logged along with stderr and stdout (debug only).
+
+ // If the service is not running, return ErrNotNeeded.
+ // Note systemctl reload returns an error if the unit is not active, and
+ // try-reload-or-restart is too recent of an addition for centos7.
+ code := s.systemctl(ctx, slog.LevelDebug, "is-active", "--quiet", s.ServiceName)
+ switch {
+ case code < 0:
+ return trace.Errorf("unable to determine if systemd service is active")
+ case code > 0:
+ s.Log.WarnContext(ctx, "Teleport systemd service not running.")
+ return trace.Wrap(ErrNotNeeded)
+ }
+ // Attempt graceful reload of running service.
+ code = s.systemctl(ctx, slog.LevelError, "reload", s.ServiceName)
+ switch {
+ case code < 0:
+ return trace.Errorf("unable to attempt reload of Teleport systemd service")
+ case code > 0:
+ // Graceful reload fails, try hard restart.
+ code = s.systemctl(ctx, slog.LevelError, "try-restart", s.ServiceName)
+ if code != 0 {
+ return trace.Errorf("hard restart of Teleport systemd service failed")
+ }
+ s.Log.WarnContext(ctx, "Teleport ungracefully restarted. Connections potentially dropped.")
+ default:
+ s.Log.InfoContext(ctx, "Teleport gracefully reloaded.")
+ }
+
+ // TODO(sclevine): Ensure restart was successful and verify healthcheck.
+
+ return nil
+}
+
+// Sync systemd service configuration by running systemctl daemon-reload.
+// See Process interface for more details.
+func (s SystemdService) Sync(ctx context.Context) error {
+ if err := s.checkSystem(ctx); err != nil {
+ return trace.Wrap(err)
+ }
+ code := s.systemctl(ctx, slog.LevelError, "daemon-reload")
+ if code != 0 {
+ return trace.Errorf("unable to reload systemd configuration")
+ }
+ return nil
+}
+
+// checkSystem returns an error if the system is not compatible with this process manager.
+func (s SystemdService) checkSystem(ctx context.Context) error {
+ _, err := os.Stat("/run/systemd/system")
+ if errors.Is(err, os.ErrNotExist) {
+ s.Log.ErrorContext(ctx, "This system does not support systemd, which is required by the updater.")
+ return trace.Wrap(ErrNotSupported)
+ }
+ return trace.Wrap(err)
+}
+
+// systemctl returns a systemctl subcommand, converting the output to logs.
+// Output sent to stdout is logged at debug level.
+// Output sent to stderr is logged at the level specified by errLevel.
+func (s SystemdService) systemctl(ctx context.Context, errLevel slog.Level, args ...string) int {
+ cmd := exec.CommandContext(ctx, "systemctl", args...)
+ stderr := &lineLogger{ctx: ctx, log: s.Log, level: errLevel}
+ stdout := &lineLogger{ctx: ctx, log: s.Log, level: slog.LevelDebug}
+ cmd.Stderr = stderr
+ cmd.Stdout = stdout
+ err := cmd.Run()
+ stderr.Flush()
+ stdout.Flush()
+ code := cmd.ProcessState.ExitCode()
+
+ // Treat out-of-range exit code (255) as an error executing the command.
+ // This allows callers to treat codes that are more likely OS-related as execution errors
+ // instead of intentionally returned error codes.
+ if code == 255 {
+ code = -1
+ }
+ if err != nil {
+ s.Log.Log(ctx, errLevel, "Failed to run systemctl.",
+ "args", args,
+ "code", code,
+ "error", err)
+ }
+ return code
+}
+
+// lineLogger logs each line written to it.
+type lineLogger struct {
+ ctx context.Context
+ log *slog.Logger
+ level slog.Level
+
+ last bytes.Buffer
+}
+
+func (w *lineLogger) Write(p []byte) (n int, err error) {
+ lines := bytes.Split(p, []byte("\n"))
+ // Finish writing line
+ if len(lines) > 0 {
+ n, err = w.last.Write(lines[0])
+ lines = lines[1:]
+ }
+ // Quit if no newline
+ if len(lines) == 0 || err != nil {
+ return n, trace.Wrap(err)
+ }
+
+ // Newline found, log line
+ w.log.Log(w.ctx, w.level, w.last.String()) //nolint:sloglint // msg cannot be constant
+ n += 1
+ w.last.Reset()
+
+ // Log lines that are already newline-terminated
+ for _, line := range lines[:len(lines)-1] {
+ w.log.Log(w.ctx, w.level, string(line)) //nolint:sloglint // msg cannot be constant
+ n += len(line) + 1
+ }
+
+ // Store remaining line non-newline-terminated line.
+ n2, err := w.last.Write(lines[len(lines)-1])
+ n += n2
+ return n, trace.Wrap(err)
+}
+
+// Flush logs any trailing bytes that were never terminated with a newline.
+func (w *lineLogger) Flush() {
+ if w.last.Len() == 0 {
+ return
+ }
+ w.log.Log(w.ctx, w.level, w.last.String()) //nolint:sloglint // msg cannot be constant
+ w.last.Reset()
+}
diff --git a/lib/autoupdate/agent/process_test.go b/lib/autoupdate/agent/process_test.go
new file mode 100644
index 0000000000000..5ffa70dd0091e
--- /dev/null
+++ b/lib/autoupdate/agent/process_test.go
@@ -0,0 +1,71 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package agent
+
+import (
+ "bytes"
+ "context"
+ "log/slog"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestLineLogger(t *testing.T) {
+ t.Parallel()
+
+ out := &bytes.Buffer{}
+ ll := lineLogger{
+ ctx: context.Background(),
+ log: slog.New(slog.NewTextHandler(out,
+ &slog.HandlerOptions{ReplaceAttr: msgOnly},
+ )),
+ }
+
+ for _, e := range []struct {
+ v string
+ n int
+ }{
+ {v: "", n: 0},
+ {v: "a", n: 1},
+ {v: "b\n", n: 2},
+ {v: "c\nd", n: 3},
+ {v: "e\nf\ng", n: 5},
+ {v: "h", n: 1},
+ {v: "", n: 0},
+ {v: "\n", n: 1},
+ {v: "i\n", n: 2},
+ {v: "j", n: 1},
+ } {
+ n, err := ll.Write([]byte(e.v))
+ require.NoError(t, err)
+ require.Equal(t, e.n, n)
+ }
+ require.Equal(t, "msg=ab\nmsg=c\nmsg=de\nmsg=f\nmsg=gh\nmsg=i\n", out.String())
+ ll.Flush()
+ require.Equal(t, "msg=ab\nmsg=c\nmsg=de\nmsg=f\nmsg=gh\nmsg=i\nmsg=j\n", out.String())
+}
+
+func msgOnly(_ []string, a slog.Attr) slog.Attr {
+ switch a.Key {
+ case "time", "level":
+ return slog.Attr{}
+ }
+ return slog.Attr{Key: a.Key, Value: a.Value}
+}
diff --git a/lib/autoupdate/agent/updater.go b/lib/autoupdate/agent/updater.go
index 7071f16e42d15..b82c3c6d419cb 100644
--- a/lib/autoupdate/agent/updater.go
+++ b/lib/autoupdate/agent/updater.go
@@ -135,6 +135,10 @@ func NewLocalUpdater(cfg LocalUpdaterConfig) (*Updater, error) {
ReservedFreeTmpDisk: reservedFreeDisk,
ReservedFreeInstallDisk: reservedFreeDisk,
},
+ Process: &SystemdService{
+ ServiceName: "teleport.service",
+ Log: cfg.Log,
+ },
}, nil
}
@@ -166,26 +170,58 @@ type Updater struct {
ConfigPath string
// Installer manages installations of the Teleport agent.
Installer Installer
+ // Process manages a running instance of Teleport.
+ Process Process
}
// Installer provides an API for installing Teleport agents.
type Installer interface {
// Install the Teleport agent at version from the download template.
- // This function must be idempotent.
+ // Install must be idempotent.
Install(ctx context.Context, version, template string, flags InstallFlags) error
- // Link the Teleport agent at version into the system location.
- // This function must be idempotent.
- Link(ctx context.Context, version string) error
+ // Link the Teleport agent at the specified version into the system location.
+ // The revert function must restore the previous linking, returning false on any failure.
+ // Link must be idempotent.
+ // Link's revert function must be idempotent.
+ Link(ctx context.Context, version string) (revert func(context.Context) bool, err error)
// List the installed versions of Teleport.
List(ctx context.Context) (versions []string, err error)
// Remove the Teleport agent at version.
- // This function must be idempotent.
+ // Must return ErrLinked if unable to remove due to being linked.
+ // Remove must be idempotent.
Remove(ctx context.Context, version string) error
}
+var (
+ // ErrLinked is returned when a linked version cannot be operated on.
+ ErrLinked = errors.New("version is linked")
+ // ErrNotNeeded is returned when the operation is not needed.
+ ErrNotNeeded = errors.New("not needed")
+ // ErrNotSupported is returned when the operation is not supported on the platform.
+ ErrNotSupported = errors.New("not supported on this platform")
+)
+
+// Process provides an API for interacting with a running Teleport process.
+type Process interface {
+ // Reload must reload the Teleport process as gracefully as possible.
+ // If the process is not healthy after reloading, Reload must return an error.
+ // If the process did not require reloading, Reload must return ErrNotNeeded.
+ // E.g., if the process is not enabled, or it was already reloaded after the last Sync.
+ // If the type implementing Process does not support the system process manager,
+ // Reload must return ErrNotSupported.
+ Reload(ctx context.Context) error
+ // Sync must validate and synchronize process configuration.
+ // After the linked Teleport installation is changed, failure to call Sync without
+ // error before Reload may result in undefined behavior.
+ // If the type implementing Process does not support the system process manager,
+ // Sync must return ErrNotSupported.
+ Sync(ctx context.Context) error
+}
+
// InstallFlags sets flags for the Teleport installation
type InstallFlags int
+// TODO(sclevine): add flags for need_restart and selinux config
const (
// FlagEnterprise installs enterprise Teleport
FlagEnterprise InstallFlags = 1 << iota
@@ -215,30 +251,20 @@ type OverrideConfig struct {
// This function is idempotent.
func (u *Updater) Enable(ctx context.Context, override OverrideConfig) error {
// Read configuration from update.yaml and override any new values passed as flags.
- cfg, err := u.readConfig(u.ConfigPath)
+ cfg, err := readConfig(u.ConfigPath)
if err != nil {
return trace.Errorf("failed to read %s: %w", updateConfigName, err)
}
- if override.Proxy != "" {
- cfg.Spec.Proxy = override.Proxy
- }
- if override.Group != "" {
- cfg.Spec.Group = override.Group
- }
- if override.URLTemplate != "" {
- cfg.Spec.URLTemplate = override.URLTemplate
- }
- cfg.Spec.Enabled = true
- if err := validateUpdatesSpec(&cfg.Spec); err != nil {
+ if err := validateConfigSpec(&cfg.Spec, override); err != nil {
return trace.Wrap(err)
}
// Lookup target version from the proxy.
+
addr, err := libutils.ParseAddr(cfg.Spec.Proxy)
if err != nil {
return trace.Errorf("failed to parse proxy server address: %w", err)
}
-
desiredVersion := override.ForceVersion
var flags InstallFlags
if desiredVersion == "" {
@@ -278,7 +304,9 @@ func (u *Updater) Enable(ctx context.Context, override OverrideConfig) error {
u.Log.WarnContext(ctx, "Failed to remove backup version of Teleport before new install.", "error", err)
}
}
- // If the active version and target don't match, kick off upgrade.
+
+ // Install the desired version (or validate existing installation)
+
template := cfg.Spec.URLTemplate
if template == "" {
template = cdnURITemplate
@@ -287,14 +315,55 @@ func (u *Updater) Enable(ctx context.Context, override OverrideConfig) error {
if err != nil {
return trace.Errorf("failed to install: %w", err)
}
- err = u.Installer.Link(ctx, desiredVersion)
+ revert, err := u.Installer.Link(ctx, desiredVersion)
if err != nil {
return trace.Errorf("failed to link: %w", err)
}
+
+ // If we fail to revert after this point, the next update/enable will
+ // fix the link to restore the active version.
+
+ // Sync process configuration after linking.
+
+ if err := u.Process.Sync(ctx); err != nil {
+ if errors.Is(err, context.Canceled) {
+ return trace.Errorf("sync canceled")
+ }
+ // If sync fails, we may have left the host in a bad state, so we revert linking and re-Sync.
+ u.Log.ErrorContext(ctx, "Reverting symlinks due to invalid configuration.")
+ if ok := revert(ctx); !ok {
+ u.Log.ErrorContext(ctx, "Failed to revert Teleport symlinks. Installation likely broken.")
+ } else if err := u.Process.Sync(ctx); err != nil {
+ u.Log.ErrorContext(ctx, "Failed to sync configuration after failed restart.", "error", err)
+ }
+ u.Log.WarnContext(ctx, "Teleport updater encountered a configuration error and successfully reverted the installation.")
+
+ return trace.Errorf("failed to validate configuration for new version %q of Teleport: %w", desiredVersion, err)
+ }
+
+ // Restart Teleport if necessary.
+
if cfg.Status.ActiveVersion != desiredVersion {
+ u.Log.InfoContext(ctx, "Target version successfully installed.", "version", desiredVersion)
+ if err := u.Process.Reload(ctx); err != nil && !errors.Is(err, ErrNotNeeded) {
+ if errors.Is(err, context.Canceled) {
+ return trace.Errorf("reload canceled")
+ }
+ // If reloading Teleport at the new version fails, revert, resync, and reload.
+ u.Log.ErrorContext(ctx, "Reverting symlinks due to failed restart.")
+ if ok := revert(ctx); !ok {
+ u.Log.ErrorContext(ctx, "Failed to revert Teleport symlinks to older version. Installation likely broken.")
+ } else if err := u.Process.Sync(ctx); err != nil {
+ u.Log.ErrorContext(ctx, "Invalid configuration found after reverting Teleport to older version. Installation likely broken.", "error", err)
+ } else if err := u.Process.Reload(ctx); err != nil && !errors.Is(err, ErrNotNeeded) {
+ u.Log.ErrorContext(ctx, "Failed to revert Teleport to older version. Installation likely broken.", "error", err)
+ }
+ u.Log.WarnContext(ctx, "Teleport updater encountered a configuration error and successfully reverted the installation.")
+
+ return trace.Errorf("failed to start new version %q of Teleport: %w", desiredVersion, err)
+ }
cfg.Status.BackupVersion = cfg.Status.ActiveVersion
cfg.Status.ActiveVersion = desiredVersion
- u.Log.InfoContext(ctx, "Target version successfully installed.", "version", desiredVersion)
} else {
u.Log.InfoContext(ctx, "Target version successfully validated.", "version", desiredVersion)
}
@@ -302,6 +371,8 @@ func (u *Updater) Enable(ctx context.Context, override OverrideConfig) error {
u.Log.InfoContext(ctx, "Backup version set.", "version", v)
}
+ // Check if manual cleanup might be needed.
+
versions, err := u.Installer.List(ctx)
if err != nil {
return trace.Errorf("failed to list installed versions: %w", err)
@@ -311,29 +382,19 @@ func (u *Updater) Enable(ctx context.Context, override OverrideConfig) error {
}
// Always write the configuration file if enable succeeds.
- if err := u.writeConfig(u.ConfigPath, cfg); err != nil {
+
+ cfg.Spec.Enabled = true
+ if err := writeConfig(u.ConfigPath, cfg); err != nil {
return trace.Errorf("failed to write %s: %w", updateConfigName, err)
}
u.Log.InfoContext(ctx, "Configuration updated.")
return nil
}
-func validateUpdatesSpec(spec *UpdateSpec) error {
- if spec.URLTemplate != "" &&
- !strings.HasPrefix(strings.ToLower(spec.URLTemplate), "https://") {
- return trace.Errorf("Teleport download URL must use TLS (https://)")
- }
-
- if spec.Proxy == "" {
- return trace.Errorf("Teleport proxy URL must be specified with --proxy or present in %s", updateConfigName)
- }
- return nil
-}
-
// Disable disables agent auto-updates.
// This function is idempotent.
func (u *Updater) Disable(ctx context.Context) error {
- cfg, err := u.readConfig(u.ConfigPath)
+ cfg, err := readConfig(u.ConfigPath)
if err != nil {
return trace.Errorf("failed to read %s: %w", updateConfigName, err)
}
@@ -342,14 +403,14 @@ func (u *Updater) Disable(ctx context.Context) error {
return nil
}
cfg.Spec.Enabled = false
- if err := u.writeConfig(u.ConfigPath, cfg); err != nil {
+ if err := writeConfig(u.ConfigPath, cfg); err != nil {
return trace.Errorf("failed to write %s: %w", updateConfigName, err)
}
return nil
}
// readConfig reads UpdateConfig from a file.
-func (*Updater) readConfig(path string) (*UpdateConfig, error) {
+func readConfig(path string) (*UpdateConfig, error) {
f, err := os.Open(path)
if errors.Is(err, fs.ErrNotExist) {
return &UpdateConfig{
@@ -375,7 +436,7 @@ func (*Updater) readConfig(path string) (*UpdateConfig, error) {
}
// writeConfig writes UpdateConfig to a file atomically, ensuring the file cannot be corrupted.
-func (*Updater) writeConfig(filename string, cfg *UpdateConfig) error {
+func writeConfig(filename string, cfg *UpdateConfig) error {
opts := []renameio.Option{
renameio.WithPermissions(0755),
renameio.WithExistingPermissions(),
@@ -391,3 +452,23 @@ func (*Updater) writeConfig(filename string, cfg *UpdateConfig) error {
}
return trace.Wrap(t.CloseAtomicallyReplace())
}
+
+func validateConfigSpec(spec *UpdateSpec, override OverrideConfig) error {
+ if override.Proxy != "" {
+ spec.Proxy = override.Proxy
+ }
+ if override.Group != "" {
+ spec.Group = override.Group
+ }
+ if override.URLTemplate != "" {
+ spec.URLTemplate = override.URLTemplate
+ }
+ if spec.URLTemplate != "" &&
+ !strings.HasPrefix(strings.ToLower(spec.URLTemplate), "https://") {
+ return trace.Errorf("Teleport download URL must use TLS (https://)")
+ }
+ if spec.Proxy == "" {
+ return trace.Errorf("Teleport proxy URL must be specified with --proxy or present in %s", updateConfigName)
+ }
+ return nil
+}
diff --git a/lib/autoupdate/agent/updater_test.go b/lib/autoupdate/agent/updater_test.go
index e817851fed1f7..8cefd3a59e3e7 100644
--- a/lib/autoupdate/agent/updater_test.go
+++ b/lib/autoupdate/agent/updater_test.go
@@ -132,11 +132,16 @@ func TestUpdater_Enable(t *testing.T) {
userCfg OverrideConfig
installErr error
flags InstallFlags
+ syncErr error
+ reloadErr error
removedVersion string
installedVersion string
installedTemplate string
requestGroup string
+ syncCalls int
+ reloadCalls int
+ revertCalls int
errMatch string
}{
{
@@ -152,9 +157,12 @@ func TestUpdater_Enable(t *testing.T) {
ActiveVersion: "old-version",
},
},
+
installedVersion: "16.3.0",
installedTemplate: "https://example.com",
requestGroup: "group",
+ syncCalls: 1,
+ reloadCalls: 1,
},
{
name: "config from user",
@@ -174,8 +182,11 @@ func TestUpdater_Enable(t *testing.T) {
URLTemplate: "https://example.com/new",
ForceVersion: "new-version",
},
+
installedVersion: "new-version",
installedTemplate: "https://example.com/new",
+ syncCalls: 1,
+ reloadCalls: 1,
},
{
name: "already enabled",
@@ -189,8 +200,11 @@ func TestUpdater_Enable(t *testing.T) {
ActiveVersion: "old-version",
},
},
+
installedVersion: "16.3.0",
installedTemplate: cdnURITemplate,
+ syncCalls: 1,
+ reloadCalls: 1,
},
{
name: "insecure URL",
@@ -201,6 +215,7 @@ func TestUpdater_Enable(t *testing.T) {
URLTemplate: "http://example.com",
},
},
+
errMatch: "URL must use TLS",
},
{
@@ -213,7 +228,8 @@ func TestUpdater_Enable(t *testing.T) {
},
},
installErr: errors.New("install error"),
- errMatch: "install error",
+
+ errMatch: "install error",
},
{
name: "version already installed",
@@ -224,8 +240,11 @@ func TestUpdater_Enable(t *testing.T) {
ActiveVersion: "16.3.0",
},
},
+
installedVersion: "16.3.0",
installedTemplate: cdnURITemplate,
+ syncCalls: 1,
+ reloadCalls: 0,
},
{
name: "backup version removed on install",
@@ -237,9 +256,12 @@ func TestUpdater_Enable(t *testing.T) {
BackupVersion: "backup-version",
},
},
+
installedVersion: "16.3.0",
installedTemplate: cdnURITemplate,
removedVersion: "backup-version",
+ syncCalls: 1,
+ reloadCalls: 1,
},
{
name: "backup version kept for validation",
@@ -251,26 +273,56 @@ func TestUpdater_Enable(t *testing.T) {
BackupVersion: "backup-version",
},
},
+
installedVersion: "16.3.0",
installedTemplate: cdnURITemplate,
removedVersion: "",
+ syncCalls: 1,
+ reloadCalls: 0,
},
{
- name: "config does not exist",
+ name: "config does not exist",
+
installedVersion: "16.3.0",
installedTemplate: cdnURITemplate,
+ syncCalls: 1,
+ reloadCalls: 1,
},
{
name: "FIPS and Enterprise flags",
flags: FlagEnterprise | FlagFIPS,
installedVersion: "16.3.0",
installedTemplate: cdnURITemplate,
+ syncCalls: 1,
+ reloadCalls: 1,
},
{
name: "invalid metadata",
cfg: &UpdateConfig{},
errMatch: "invalid",
},
+ {
+ name: "sync fails",
+ syncErr: errors.New("sync error"),
+
+ installedVersion: "16.3.0",
+ installedTemplate: cdnURITemplate,
+ syncCalls: 2,
+ reloadCalls: 0,
+ revertCalls: 1,
+ errMatch: "sync error",
+ },
+ {
+ name: "reload fails",
+ reloadErr: errors.New("reload error"),
+
+ installedVersion: "16.3.0",
+ installedTemplate: cdnURITemplate,
+ syncCalls: 2,
+ reloadCalls: 2,
+ revertCalls: 1,
+ errMatch: "reload error",
+ },
}
for _, tt := range tests {
@@ -320,6 +372,7 @@ func TestUpdater_Enable(t *testing.T) {
linkedVersion string
removedVersion string
installedFlags InstallFlags
+ revertCalls int
)
updater.Installer = &testInstaller{
FuncInstall: func(_ context.Context, version, template string, flags InstallFlags) error {
@@ -328,9 +381,12 @@ func TestUpdater_Enable(t *testing.T) {
installedFlags = flags
return tt.installErr
},
- FuncLink: func(_ context.Context, version string) error {
+ FuncLink: func(_ context.Context, version string) (revert func(context.Context) bool, err error) {
linkedVersion = version
- return nil
+ return func(_ context.Context) bool {
+ revertCalls++
+ return true
+ }, nil
},
FuncList: func(_ context.Context) (versions []string, err error) {
return []string{"old"}, nil
@@ -340,6 +396,20 @@ func TestUpdater_Enable(t *testing.T) {
return nil
},
}
+ var (
+ syncCalls int
+ reloadCalls int
+ )
+ updater.Process = &testProcess{
+ FuncSync: func(_ context.Context) error {
+ syncCalls++
+ return tt.syncErr
+ },
+ FuncReload: func(_ context.Context) error {
+ reloadCalls++
+ return tt.reloadErr
+ },
+ }
ctx := context.Background()
err = updater.Enable(ctx, tt.userCfg)
@@ -355,6 +425,9 @@ func TestUpdater_Enable(t *testing.T) {
require.Equal(t, tt.removedVersion, removedVersion)
require.Equal(t, tt.flags, installedFlags)
require.Equal(t, tt.requestGroup, requestedGroup)
+ require.Equal(t, tt.syncCalls, syncCalls)
+ require.Equal(t, tt.reloadCalls, reloadCalls)
+ require.Equal(t, tt.revertCalls, revertCalls)
data, err := os.ReadFile(cfgPath)
require.NoError(t, err)
@@ -377,7 +450,7 @@ func blankTestAddr(s []byte) []byte {
type testInstaller struct {
FuncInstall func(ctx context.Context, version, template string, flags InstallFlags) error
FuncRemove func(ctx context.Context, version string) error
- FuncLink func(ctx context.Context, version string) error
+ FuncLink func(ctx context.Context, version string) (revert func(context.Context) bool, err error)
FuncList func(ctx context.Context) (versions []string, err error)
}
@@ -389,10 +462,23 @@ func (ti *testInstaller) Remove(ctx context.Context, version string) error {
return ti.FuncRemove(ctx, version)
}
-func (ti *testInstaller) Link(ctx context.Context, version string) error {
+func (ti *testInstaller) Link(ctx context.Context, version string) (revert func(context.Context) bool, err error) {
return ti.FuncLink(ctx, version)
}
func (ti *testInstaller) List(ctx context.Context) (versions []string, err error) {
return ti.FuncList(ctx)
}
+
+type testProcess struct {
+ FuncReload func(ctx context.Context) error
+ FuncSync func(ctx context.Context) error
+}
+
+func (tp *testProcess) Reload(ctx context.Context) error {
+ return tp.FuncReload(ctx)
+}
+
+func (tp *testProcess) Sync(ctx context.Context) error {
+ return tp.FuncSync(ctx)
+}
diff --git a/lib/utils/unpack.go b/lib/utils/unpack.go
index 78b111daf8992..14b213f08a173 100644
--- a/lib/utils/unpack.go
+++ b/lib/utils/unpack.go
@@ -50,7 +50,8 @@ func Extract(r io.Reader, dir string, paths ...ExtractPath) error {
} else if err != nil {
return trace.Wrap(err)
}
- if ok := filterHeader(header, paths); !ok {
+ dirMode, ok := filterHeader(header, paths)
+ if !ok {
continue
}
err = sanitizeTarPath(header, dir)
@@ -58,7 +59,7 @@ func Extract(r io.Reader, dir string, paths ...ExtractPath) error {
return trace.Wrap(err)
}
- if err := extractFile(tarball, header, dir); err != nil {
+ if err := extractFile(tarball, header, dir, dirMode); err != nil {
return trace.Wrap(err)
}
}
@@ -74,11 +75,15 @@ type ExtractPath struct {
Src, Dst string
// Skip extracting the Src path and ignore Dst.
Skip bool
+ // DirMode is the file mode for implicit parent directories in Dst.
+ DirMode os.FileMode
}
// filterHeader modifies the tar header by filtering it through the ExtractPaths.
// filterHeader returns false if the tar header should be skipped.
-func filterHeader(hdr *tar.Header, paths []ExtractPath) (include bool) {
+// If no paths are provided, filterHeader assumes the header should be included, and sets
+// the mode for implicit parent directories to teleport.DirMaskSharedGroup.
+func filterHeader(hdr *tar.Header, paths []ExtractPath) (dirMode os.FileMode, include bool) {
name := path.Clean(hdr.Name)
for _, p := range paths {
src := path.Clean(p.Src)
@@ -98,14 +103,14 @@ func filterHeader(hdr *tar.Header, paths []ExtractPath) (include bool) {
dst += "/" // tar directory headers end in /
}
hdr.Name = dst
- return !p.Skip
+ return p.DirMode, !p.Skip
default:
// If name is a file, then
// if src is an exact match to the file name, assume src is a file and write directly to dst,
// otherwise, assume src is a directory prefix, and replace that prefix with dst.
if src == name {
hdr.Name = path.Clean(p.Dst)
- return !p.Skip
+ return p.DirMode, !p.Skip
}
if src != "/" {
src += "/" // ensure HasPrefix does not match partial names
@@ -114,26 +119,26 @@ func filterHeader(hdr *tar.Header, paths []ExtractPath) (include bool) {
continue
}
hdr.Name = path.Join(p.Dst, strings.TrimPrefix(name, src))
- return !p.Skip
+ return p.DirMode, !p.Skip
}
}
- return len(paths) == 0
+ return teleport.DirMaskSharedGroup, len(paths) == 0
}
// extractFile extracts a single file or directory from tarball into dir.
// Uses header to determine the type of item to create
// Based on https://github.com/mholt/archiver
-func extractFile(tarball *tar.Reader, header *tar.Header, dir string) error {
+func extractFile(tarball *tar.Reader, header *tar.Header, dir string, dirMode os.FileMode) error {
switch header.Typeflag {
case tar.TypeDir:
- return withDir(filepath.Join(dir, header.Name), nil)
+ return withDir(filepath.Join(dir, header.Name), dirMode, nil)
case tar.TypeBlock, tar.TypeChar, tar.TypeReg, tar.TypeFifo:
- return writeFile(filepath.Join(dir, header.Name), tarball, header.FileInfo().Mode())
+ return writeFile(filepath.Join(dir, header.Name), tarball, header.FileInfo().Mode(), dirMode)
case tar.TypeLink:
- return writeHardLink(filepath.Join(dir, header.Name), filepath.Join(dir, header.Linkname))
+ return writeHardLink(filepath.Join(dir, header.Name), filepath.Join(dir, header.Linkname), dirMode)
case tar.TypeSymlink:
- return writeSymbolicLink(filepath.Join(dir, header.Name), header.Linkname)
+ return writeSymbolicLink(filepath.Join(dir, header.Name), header.Linkname, dirMode)
default:
log.Warnf("Unsupported type flag %v for %v.", header.Typeflag, header.Name)
}
@@ -168,8 +173,8 @@ func sanitizeTarPath(header *tar.Header, dir string) error {
return nil
}
-func writeFile(path string, r io.Reader, mode os.FileMode) error {
- err := withDir(path, func() error {
+func writeFile(path string, r io.Reader, mode, dirMode os.FileMode) error {
+ err := withDir(path, dirMode, func() error {
// Create file only if it does not exist to prevent overwriting existing
// files (like session recordings).
out, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_WRONLY, mode)
@@ -182,24 +187,24 @@ func writeFile(path string, r io.Reader, mode os.FileMode) error {
return trace.Wrap(err)
}
-func writeSymbolicLink(path string, target string) error {
- err := withDir(path, func() error {
+func writeSymbolicLink(path, target string, dirMode os.FileMode) error {
+ err := withDir(path, dirMode, func() error {
err := os.Symlink(target, path)
return trace.ConvertSystemError(err)
})
return trace.Wrap(err)
}
-func writeHardLink(path string, target string) error {
- err := withDir(path, func() error {
+func writeHardLink(path, target string, dirMode os.FileMode) error {
+ err := withDir(path, dirMode, func() error {
err := os.Link(target, path)
return trace.ConvertSystemError(err)
})
return trace.Wrap(err)
}
-func withDir(path string, fn func() error) error {
- err := os.MkdirAll(filepath.Dir(path), teleport.DirMaskSharedGroup)
+func withDir(path string, mode os.FileMode, fn func() error) error {
+ err := os.MkdirAll(filepath.Dir(path), mode)
if err != nil {
return trace.ConvertSystemError(err)
}
diff --git a/tool/teleport-update/main.go b/tool/teleport-update/main.go
index 300da6736471a..2adce83a1877c 100644
--- a/tool/teleport-update/main.go
+++ b/tool/teleport-update/main.go
@@ -61,6 +61,8 @@ const (
versionsDirName = "versions"
// lockFileName specifies the name of the file inside versionsDirName containing the flock lock preventing concurrent updater execution.
lockFileName = ".lock"
+ // defaultLinkDir is the default location where Teleport binaries and services are linked.
+ defaultLinkDir = "/usr/local"
)
var plog = logutils.NewPackageLogger(teleport.ComponentKey, teleport.ComponentUpdater)
@@ -98,7 +100,7 @@ func Run(args []string) error {
app.Flag("log-format", "Controls the format of output logs. Can be `json` or `text`. Defaults to `text`.").
Default(libutils.LogFormatText).EnumVar(&ccfg.LogFormat, libutils.LogFormatJSON, libutils.LogFormatText)
app.Flag("link-dir", "Directory to create system symlinks to binaries and services.").
- Default(filepath.Join("usr", "local")).Hidden().StringVar(&ccfg.LinkDir)
+ Default(defaultLinkDir).Hidden().StringVar(&ccfg.LinkDir)
app.HelpFlag.Short('h')