diff --git a/lib/autoupdate/agent/config.go b/lib/autoupdate/agent/config.go
new file mode 100644
index 0000000000000..6664fcaa94131
--- /dev/null
+++ b/lib/autoupdate/agent/config.go
@@ -0,0 +1,122 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package agent
+
+import (
+ "context"
+ "errors"
+ "log/slog"
+ "os"
+ "path/filepath"
+ "text/template"
+
+ "github.com/google/renameio/v2"
+ "github.com/gravitational/trace"
+)
+
+const (
+ updateServiceTemplate = `# teleport-update
+[Unit]
+Description=Teleport auto-update service
+
+[Service]
+Type=oneshot
+ExecStart={{.LinkDir}}/bin/teleport-update update
+`
+ updateTimerTemplate = `# teleport-update
+[Unit]
+Description=Teleport auto-update timer unit
+
+[Timer]
+OnActiveSec=1m
+OnUnitActiveSec=5m
+RandomizedDelaySec=1m
+
+[Install]
+WantedBy=teleport.service
+`
+)
+
+func Setup(ctx context.Context, log *slog.Logger, linkDir, dataDir string) error {
+ err := writeConfigFiles(linkDir, dataDir)
+ if err != nil {
+ return trace.Errorf("failed to write teleport-update systemd config files: %w", err)
+ }
+ svc := &SystemdService{
+ ServiceName: "teleport-update.timer",
+ Log: log,
+ }
+ err = svc.Sync(ctx)
+ if errors.Is(err, ErrNotSupported) {
+ log.WarnContext(ctx, "Not enabling systemd service because systemd is not running.")
+ return nil
+ }
+ if err != nil {
+ return trace.Errorf("failed to sync systemd config: %w", err)
+ }
+ if err := svc.Enable(ctx, true); err != nil {
+ return trace.Errorf("failed to enable teleport-update systemd timer: %w", err)
+ }
+ return nil
+}
+
+func writeConfigFiles(linkDir, dataDir string) error {
+ // TODO(sclevine): revert on failure
+
+ servicePath := filepath.Join(linkDir, serviceDir, updateServiceName)
+ err := writeTemplate(servicePath, updateServiceTemplate, linkDir, dataDir)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ timerPath := filepath.Join(linkDir, serviceDir, updateTimerName)
+ err = writeTemplate(timerPath, updateTimerTemplate, linkDir, dataDir)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ return nil
+}
+
+func writeTemplate(path, t, linkDir, dataDir string) error {
+ if err := os.MkdirAll(filepath.Dir(path), systemDirMode); err != nil {
+ return trace.Wrap(err)
+ }
+ opts := []renameio.Option{
+ renameio.WithPermissions(configFileMode),
+ renameio.WithExistingPermissions(),
+ }
+ f, err := renameio.NewPendingFile(path, opts...)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ defer f.Cleanup()
+
+ tmpl, err := template.New(filepath.Base(path)).Parse(t)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ err = tmpl.Execute(f, struct {
+ LinkDir string
+ DataDir string
+ }{linkDir, dataDir})
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ return trace.Wrap(f.CloseAtomicallyReplace())
+}
diff --git a/lib/autoupdate/agent/config_test.go b/lib/autoupdate/agent/config_test.go
new file mode 100644
index 0000000000000..16cbdb5374fb6
--- /dev/null
+++ b/lib/autoupdate/agent/config_test.go
@@ -0,0 +1,65 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package agent
+
+import (
+ "bytes"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ libdefaults "github.com/gravitational/teleport/lib/defaults"
+ "github.com/gravitational/teleport/lib/utils/golden"
+)
+
+func TestWriteConfigFiles(t *testing.T) {
+ t.Parallel()
+ linkDir := t.TempDir()
+ dataDir := t.TempDir()
+ err := writeConfigFiles(linkDir, dataDir)
+ require.NoError(t, err)
+
+ for _, p := range []string{
+ filepath.Join(linkDir, serviceDir, updateServiceName),
+ filepath.Join(linkDir, serviceDir, updateTimerName),
+ } {
+ t.Run(filepath.Base(p), func(t *testing.T) {
+ data, err := os.ReadFile(p)
+ require.NoError(t, err)
+ data = replaceValues(data, map[string]string{
+ DefaultLinkDir: linkDir,
+ libdefaults.DataDir: dataDir,
+ })
+ if golden.ShouldSet() {
+ golden.Set(t, data)
+ }
+ require.Equal(t, string(golden.Get(t)), string(data))
+ })
+ }
+}
+
+func replaceValues(data []byte, m map[string]string) []byte {
+ for k, v := range m {
+ data = bytes.ReplaceAll(data, []byte(v),
+ []byte(k))
+ }
+ return data
+}
diff --git a/lib/autoupdate/agent/installer.go b/lib/autoupdate/agent/installer.go
index 957e90779c2ab..2d31d26fd8262 100644
--- a/lib/autoupdate/agent/installer.go
+++ b/lib/autoupdate/agent/installer.go
@@ -55,11 +55,15 @@ const (
systemDirMode = 0755
)
-var (
+const (
// serviceDir contains the relative path to the Teleport SystemD service dir.
- serviceDir = filepath.Join("lib", "systemd", "system")
+ serviceDir = "lib/systemd/system"
// serviceName contains the name of the Teleport SystemD service file.
serviceName = "teleport.service"
+ // updateServiceName contains the name of the Teleport Update Systemd service
+ updateServiceName = "teleport-update.service"
+ // updateTimerName contains the name of the Teleport Update Systemd timer
+ updateTimerName = "teleport-update.timer"
)
// LocalInstaller manages the creation and removal of installations
@@ -539,7 +543,7 @@ func (li *LocalInstaller) forceLinks(ctx context.Context, binDir, svcDir string)
dst := filepath.Join(li.LinkServiceDir, serviceName)
orig, err := forceCopy(dst, src, maxServiceFileSize)
if err != nil && !errors.Is(err, os.ErrExist) {
- return revert, trace.Errorf("failed to create file for %s: %w", serviceName, err)
+ return revert, trace.Errorf("failed to write file %s: %w", serviceName, err)
}
if orig != nil {
revertFiles = append(revertFiles, *orig)
@@ -782,13 +786,5 @@ func (li *LocalInstaller) isLinked(versionDir string) (bool, error) {
return true, nil
}
}
- linkData, err := readFileN(filepath.Join(li.LinkServiceDir, serviceName), maxServiceFileSize)
- if err != nil {
- return false, nil
- }
- versionData, err := readFileN(filepath.Join(versionDir, serviceDir, serviceName), maxServiceFileSize)
- if err != nil {
- return false, nil
- }
- return bytes.Equal(linkData, versionData), nil
+ return false, nil
}
diff --git a/lib/autoupdate/agent/process.go b/lib/autoupdate/agent/process.go
index eba70aa56a690..082e61156369b 100644
--- a/lib/autoupdate/agent/process.go
+++ b/lib/autoupdate/agent/process.go
@@ -25,25 +25,51 @@ import (
"log/slog"
"os"
"os/exec"
+ "strconv"
+ "syscall"
+ "time"
"github.com/gravitational/trace"
+ "golang.org/x/sync/errgroup"
)
-// SystemdService manages a Teleport systemd service.
+const (
+ // crashMonitorInterval is the polling interval for determining restart times from LastRestartPath.
+ crashMonitorInterval = 2 * time.Second
+ // minRunningIntervalsBeforeStable is the number of consecutive intervals with the same running PID detect
+ // before the service is determined stable.
+ minRunningIntervalsBeforeStable = 6
+ // maxCrashesBeforeFailure is the number of total crashes detected before the service is marked as crash-looping.
+ maxCrashesBeforeFailure = 2
+ // crashMonitorTimeout
+ crashMonitorTimeout = 30 * time.Second
+)
+
+// log keys
+const (
+ unitKey = "unit"
+)
+
+// SystemdService manages a systemd service (e.g., teleport or teleport-update).
type SystemdService struct {
// ServiceName specifies the systemd service name.
ServiceName string
+ // PIDPath is a path to a file containing the service's PID.
+ PIDPath string
// Log contains a logger.
Log *slog.Logger
}
-// Reload a systemd service.
+// Reload the systemd service.
// Attempts a graceful reload before a hard restart.
// See Process interface for more details.
func (s SystemdService) Reload(ctx context.Context) error {
+ // TODO(sclevine): allow server to force restart instead of reload
+
if err := s.checkSystem(ctx); err != nil {
return trace.Wrap(err)
}
+
// Command error codes < 0 indicate that we are unable to run the command.
// Errors from s.systemctl are logged along with stderr and stdout (debug only).
@@ -55,30 +81,165 @@ func (s SystemdService) Reload(ctx context.Context) error {
case code < 0:
return trace.Errorf("unable to determine if systemd service is active")
case code > 0:
- s.Log.WarnContext(ctx, "Teleport systemd service not running.")
+ s.Log.WarnContext(ctx, "Systemd service not running.", unitKey, s.ServiceName)
return trace.Wrap(ErrNotNeeded)
}
+
+ // Get initial PID for crash monitoring.
+
+ initPID, err := readInt(s.PIDPath)
+ if errors.Is(err, os.ErrNotExist) {
+ s.Log.InfoContext(ctx, "No existing process detected. Skipping crash monitoring.", unitKey, s.ServiceName)
+ } else if err != nil {
+ s.Log.ErrorContext(ctx, "Error reading initial PID value. Skipping crash monitoring.", unitKey, s.ServiceName, errorKey, err)
+ }
+
// Attempt graceful reload of running service.
code = s.systemctl(ctx, slog.LevelError, "reload", s.ServiceName)
switch {
case code < 0:
- return trace.Errorf("unable to attempt reload of Teleport systemd service")
+ return trace.Errorf("unable to reload systemd service")
case code > 0:
// Graceful reload fails, try hard restart.
code = s.systemctl(ctx, slog.LevelError, "try-restart", s.ServiceName)
if code != 0 {
- return trace.Errorf("hard restart of Teleport systemd service failed")
+ return trace.Errorf("hard restart of systemd service failed")
}
- s.Log.WarnContext(ctx, "Teleport ungracefully restarted. Connections potentially dropped.")
+ s.Log.WarnContext(ctx, "Service ungracefully restarted. Connections potentially dropped.", unitKey, s.ServiceName)
default:
- s.Log.InfoContext(ctx, "Teleport gracefully reloaded.")
+ s.Log.InfoContext(ctx, "Gracefully reloaded.", unitKey, s.ServiceName)
+ }
+ if initPID != 0 {
+ s.Log.InfoContext(ctx, "Monitoring PID file to detect crashes.", unitKey, s.ServiceName)
+ return trace.Wrap(s.monitor(ctx, initPID))
}
+ return nil
+}
- // TODO(sclevine): Ensure restart was successful and verify healthcheck.
+// monitor for the started process to ensure it's running by polling PIDFile.
+// This function detects several types of crashes while minimizing its own runtime during updates.
+// For example, the process may crash by failing to fork (non-running PID), or looping (repeatedly changing PID),
+// or getting stuck on quit (no change in PID).
+// initPID is the PID before the restart operation has been issued.
+func (s SystemdService) monitor(ctx context.Context, initPID int) error {
+ ctx, cancel := context.WithTimeout(ctx, crashMonitorTimeout)
+ defer cancel()
+ tickC := time.NewTicker(crashMonitorInterval).C
+ pidC := make(chan int)
+ g := &errgroup.Group{}
+ g.Go(func() error {
+ return tickFile(ctx, s.PIDPath, pidC, tickC)
+ })
+ err := s.waitForStablePID(ctx, minRunningIntervalsBeforeStable, maxCrashesBeforeFailure,
+ initPID, pidC, func(pid int) error {
+ p, err := os.FindProcess(pid)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ return trace.Wrap(p.Signal(syscall.Signal(0)))
+ })
+ cancel()
+ if err := g.Wait(); err != nil {
+ s.Log.ErrorContext(ctx, "Error monitoring for crashing process.", errorKey, err, unitKey, s.ServiceName)
+ }
+ return trace.Wrap(err)
+}
+
+// monitorRestarts receives restart times on timeCh.
+// Each restart time that differs from the preceding restart time counts as a restart.
+// If maxRestarts is exceeded, monitorRestarts returns an error.
+// Each restart time that matches the proceeding restart time counts as a clean reading.
+// If minClean is reached before maxRestarts is exceeded, monitorRestarts runs nil.
+func (s SystemdService) waitForStablePID(ctx context.Context, minStable, maxCrashes, baselinePID int, pidC <-chan int, findPID func(pid int) error) error {
+ pid := baselinePID
+ var last, stale int
+ var crashes int
+ for stable := 0; stable < minStable; stable++ {
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case p := <-pidC:
+ last = pid
+ pid = p
+ }
+ // A "crash" is defined as a transition away from a new (non-baseline) PID, or
+ // an interval where the current PID remains non-running since the last check.
+ if (last != 0 && pid != last && last != baselinePID) ||
+ (stale != 0 && pid == stale && last == stale) {
+ crashes++
+ }
+ if crashes > maxCrashes {
+ return trace.Errorf("detected crashing process")
+ }
+
+ // PID can only be stable if it is a real PID that is not new, has changed at least once,
+ // and hasn't been observed as missing.
+ if pid == 0 ||
+ pid == baselinePID ||
+ pid == stale ||
+ pid != last {
+ stable = -1
+ continue
+ }
+ err := findPID(pid)
+ // A stale PID most likely indicates that the process forked and crashed without systemd noticing.
+ // There is a small chance that we read the PID file before systemd removed it.
+ // Note: we only perform this check on PIDs that survive one iteration.
+ if errors.Is(err, os.ErrProcessDone) ||
+ errors.Is(err, syscall.ESRCH) {
+ if pid != stale &&
+ pid != baselinePID {
+ stale = pid
+ s.Log.WarnContext(ctx, "Detected stale PID.", unitKey, s.ServiceName, "pid", stale)
+ }
+ stable = -1
+ continue
+ }
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ }
return nil
}
+func readInt(path string) (int, error) {
+ p, err := readFileN(path, 32)
+ if err != nil {
+ return 0, trace.Wrap(err)
+ }
+ i, err := strconv.ParseInt(string(bytes.TrimSpace(p)), 10, 64)
+ if err != nil {
+ return 0, trace.Wrap(err)
+ }
+ return int(i), nil
+}
+
+// tickFile reads the current time on tickC, and outputs the last read int from path on ch for each received tick.
+// If the path cannot be read, tickFile sends 0 on ch.
+// Any error from the last attempt to read path is returned when ctx is canceled, unless the error is os.ErrNotExist.
+func tickFile(ctx context.Context, path string, ch chan<- int, tickC <-chan time.Time) error {
+ var err error
+ for {
+ // two select statements -> never skip reads
+ select {
+ case <-tickC:
+ case <-ctx.Done():
+ return err
+ }
+ var t int
+ t, err = readInt(path)
+ if errors.Is(err, os.ErrNotExist) {
+ err = nil
+ }
+ select {
+ case ch <- t:
+ case <-ctx.Done():
+ return err
+ }
+ }
+}
+
// Sync systemd service configuration by running systemctl daemon-reload.
// See Process interface for more details.
func (s SystemdService) Sync(ctx context.Context) error {
@@ -89,6 +250,24 @@ func (s SystemdService) Sync(ctx context.Context) error {
if code != 0 {
return trace.Errorf("unable to reload systemd configuration")
}
+ s.Log.InfoContext(ctx, "Systemd configuration synced.")
+ return nil
+}
+
+// Enable the systemd service.
+func (s SystemdService) Enable(ctx context.Context, now bool) error {
+ if err := s.checkSystem(ctx); err != nil {
+ return trace.Wrap(err)
+ }
+ args := []string{"enable", s.ServiceName}
+ if now {
+ args = append(args, "--now")
+ }
+ code := s.systemctl(ctx, slog.LevelError, args...)
+ if code != 0 {
+ return trace.Errorf("unable to enable systemd service")
+ }
+ s.Log.InfoContext(ctx, "Service enabled.", unitKey, s.ServiceName)
return nil
}
@@ -106,9 +285,42 @@ func (s SystemdService) checkSystem(ctx context.Context) error {
// Output sent to stdout is logged at debug level.
// Output sent to stderr is logged at the level specified by errLevel.
func (s SystemdService) systemctl(ctx context.Context, errLevel slog.Level, args ...string) int {
- cmd := exec.CommandContext(ctx, "systemctl", args...)
- stderr := &lineLogger{ctx: ctx, log: s.Log, level: errLevel}
- stdout := &lineLogger{ctx: ctx, log: s.Log, level: slog.LevelDebug}
+ cmd := &localExec{
+ Log: s.Log,
+ ErrLevel: errLevel,
+ OutLevel: slog.LevelDebug,
+ }
+ code, err := cmd.Run(ctx, "systemctl", args...)
+ if err == nil {
+ return code
+ }
+ if code >= 0 {
+ s.Log.Log(ctx, errLevel, "Error running systemctl.",
+ "args", args, "code", code)
+ return code
+ }
+ s.Log.Log(ctx, errLevel, "Unable to run systemctl.",
+ "args", args, "code", code, errorKey, err)
+ return code
+}
+
+// localExec runs a command locally, logging any output.
+type localExec struct {
+ // Log contains a slog logger.
+ // Defaults to slog.Default() if nil.
+ Log *slog.Logger
+ // ErrLevel is the log level for stderr.
+ ErrLevel slog.Level
+ // OutLevel is the log level for stdout.
+ OutLevel slog.Level
+}
+
+// Run the command. Same arguments as exec.CommandContext.
+// Outputs the status code, or -1 if out-of-range or unstarted.
+func (c *localExec) Run(ctx context.Context, name string, args ...string) (int, error) {
+ cmd := exec.CommandContext(ctx, name, args...)
+ stderr := &lineLogger{ctx: ctx, log: c.Log, level: c.ErrLevel, prefix: "[stderr] "}
+ stdout := &lineLogger{ctx: ctx, log: c.Log, level: c.OutLevel, prefix: "[stdout] "}
cmd.Stderr = stderr
cmd.Stdout = stdout
err := cmd.Run()
@@ -122,24 +334,23 @@ func (s SystemdService) systemctl(ctx context.Context, errLevel slog.Level, args
if code == 255 {
code = -1
}
- if err != nil {
- s.Log.Log(ctx, errLevel, "Failed to run systemctl.",
- "args", args,
- "code", code,
- "error", err)
- }
- return code
+ return code, trace.Wrap(err)
}
// lineLogger logs each line written to it.
type lineLogger struct {
- ctx context.Context
- log *slog.Logger
- level slog.Level
+ ctx context.Context
+ log *slog.Logger
+ level slog.Level
+ prefix string
last bytes.Buffer
}
+func (w *lineLogger) out(s string) {
+ w.log.Log(w.ctx, w.level, w.prefix+s) //nolint:sloglint // msg cannot be constant
+}
+
func (w *lineLogger) Write(p []byte) (n int, err error) {
lines := bytes.Split(p, []byte("\n"))
// Finish writing line
@@ -153,13 +364,13 @@ func (w *lineLogger) Write(p []byte) (n int, err error) {
}
// Newline found, log line
- w.log.Log(w.ctx, w.level, w.last.String()) //nolint:sloglint // msg cannot be constant
+ w.out(w.last.String())
n += 1
w.last.Reset()
// Log lines that are already newline-terminated
for _, line := range lines[:len(lines)-1] {
- w.log.Log(w.ctx, w.level, string(line)) //nolint:sloglint // msg cannot be constant
+ w.out(string(line))
n += len(line) + 1
}
@@ -174,6 +385,6 @@ func (w *lineLogger) Flush() {
if w.last.Len() == 0 {
return
}
- w.log.Log(w.ctx, w.level, w.last.String()) //nolint:sloglint // msg cannot be constant
+ w.out(w.last.String())
w.last.Reset()
}
diff --git a/lib/autoupdate/agent/process_test.go b/lib/autoupdate/agent/process_test.go
index 5ffa70dd0091e..c558a7539831a 100644
--- a/lib/autoupdate/agent/process_test.go
+++ b/lib/autoupdate/agent/process_test.go
@@ -21,8 +21,13 @@ package agent
import (
"bytes"
"context"
+ "errors"
+ "fmt"
"log/slog"
+ "os"
+ "path/filepath"
"testing"
+ "time"
"github.com/stretchr/testify/require"
)
@@ -69,3 +74,266 @@ func msgOnly(_ []string, a slog.Attr) slog.Attr {
}
return slog.Attr{Key: a.Key, Value: a.Value}
}
+
+func TestWaitForStablePID(t *testing.T) {
+ t.Parallel()
+
+ svc := &SystemdService{
+ Log: slog.Default(),
+ }
+
+ for _, tt := range []struct {
+ name string
+ ticks []int
+ baseline int
+ minStable int
+ maxCrashes int
+ findErrs map[int]error
+
+ errored bool
+ canceled bool
+ }{
+ {
+ name: "immediate restart",
+ ticks: []int{2, 2},
+ baseline: 1,
+ minStable: 1,
+ maxCrashes: 1,
+ },
+ {
+ name: "zero stable",
+ },
+ {
+ name: "immediate crash",
+ ticks: []int{2, 3},
+ baseline: 1,
+ minStable: 1,
+ maxCrashes: 0,
+ errored: true,
+ },
+ {
+ name: "no changes times out",
+ ticks: []int{1, 1, 1, 1},
+ baseline: 1,
+ minStable: 3,
+ maxCrashes: 2,
+ canceled: true,
+ },
+ {
+ name: "baseline restart",
+ ticks: []int{2, 2, 2, 2},
+ baseline: 1,
+ minStable: 3,
+ maxCrashes: 2,
+ },
+ {
+ name: "one restart then stable",
+ ticks: []int{1, 2, 2, 2, 2},
+ baseline: 1,
+ minStable: 3,
+ maxCrashes: 2,
+ },
+ {
+ name: "two restarts then stable",
+ ticks: []int{1, 2, 3, 3, 3, 3},
+ baseline: 1,
+ minStable: 3,
+ maxCrashes: 2,
+ },
+ {
+ name: "three restarts then stable",
+ ticks: []int{1, 2, 3, 4, 4, 4, 4},
+ baseline: 1,
+ minStable: 3,
+ maxCrashes: 2,
+ },
+ {
+ name: "too many restarts excluding baseline",
+ ticks: []int{1, 2, 3, 4, 5},
+ baseline: 1,
+ minStable: 3,
+ maxCrashes: 2,
+ errored: true,
+ },
+ {
+ name: "too many restarts including baseline",
+ ticks: []int{1, 2, 3, 4},
+ baseline: 0,
+ minStable: 3,
+ maxCrashes: 2,
+ errored: true,
+ },
+ {
+ name: "too many restarts slow",
+ ticks: []int{1, 1, 1, 2, 2, 2, 3, 3, 3, 4},
+ baseline: 0,
+ minStable: 3,
+ maxCrashes: 2,
+ errored: true,
+ },
+ {
+ name: "too many restarts after stable",
+ ticks: []int{1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4},
+ baseline: 0,
+ minStable: 3,
+ maxCrashes: 2,
+ },
+ {
+ name: "stable after too many restarts",
+ ticks: []int{1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4},
+ baseline: 0,
+ minStable: 3,
+ maxCrashes: 2,
+ errored: true,
+ },
+ {
+ name: "cancel",
+ ticks: []int{1, 1, 1},
+ baseline: 0,
+ minStable: 3,
+ maxCrashes: 2,
+ canceled: true,
+ },
+ {
+ name: "stale PID crash",
+ ticks: []int{2, 2, 2, 2, 2},
+ baseline: 1,
+ minStable: 3,
+ maxCrashes: 2,
+ findErrs: map[int]error{
+ 2: os.ErrProcessDone,
+ },
+ errored: true,
+ },
+ {
+ name: "stale PID but fixed",
+ ticks: []int{2, 2, 3, 3, 3, 3},
+ baseline: 1,
+ minStable: 3,
+ maxCrashes: 2,
+ findErrs: map[int]error{
+ 2: os.ErrProcessDone,
+ },
+ },
+ {
+ name: "error PID",
+ ticks: []int{2, 2, 3, 3, 3, 3},
+ baseline: 1,
+ minStable: 3,
+ maxCrashes: 2,
+ findErrs: map[int]error{
+ 2: errors.New("bad"),
+ },
+ errored: true,
+ },
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ ctx := context.Background()
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+ ch := make(chan int)
+ go func() {
+ defer cancel() // always quit after last tick
+ for _, tick := range tt.ticks {
+ ch <- tick
+ }
+ }()
+ err := svc.waitForStablePID(ctx, tt.minStable, tt.maxCrashes,
+ tt.baseline, ch, func(pid int) error {
+ return tt.findErrs[pid]
+ })
+ require.Equal(t, tt.canceled, errors.Is(err, context.Canceled))
+ if !tt.canceled {
+ require.Equal(t, tt.errored, err != nil)
+ }
+ })
+ }
+}
+
+func TestTickFile(t *testing.T) {
+ t.Parallel()
+
+ for _, tt := range []struct {
+ name string
+ ticks []int
+ errored bool
+ }{
+ {
+ name: "consistent",
+ ticks: []int{1, 1, 1},
+ errored: false,
+ },
+ {
+ name: "divergent",
+ ticks: []int{1, 2, 3},
+ errored: false,
+ },
+ {
+ name: "start error",
+ ticks: []int{-1, 1, 1},
+ errored: false,
+ },
+ {
+ name: "ephemeral error",
+ ticks: []int{1, -1, 1},
+ errored: false,
+ },
+ {
+ name: "end error",
+ ticks: []int{1, 1, -1},
+ errored: true,
+ },
+ {
+ name: "start missing",
+ ticks: []int{0, 1, 1},
+ errored: false,
+ },
+ {
+ name: "ephemeral missing",
+ ticks: []int{1, 0, 1},
+ errored: false,
+ },
+ {
+ name: "end missing",
+ ticks: []int{1, 1, 0},
+ errored: false,
+ },
+ {
+ name: "cancel-only",
+ errored: false,
+ },
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ filePath := filepath.Join(t.TempDir(), "file")
+
+ ctx := context.Background()
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+ tickC := make(chan time.Time)
+ ch := make(chan int)
+
+ go func() {
+ defer cancel() // always quit after last tick or fail
+ for _, tick := range tt.ticks {
+ _ = os.RemoveAll(filePath)
+ switch {
+ case tick > 0:
+ err := os.WriteFile(filePath, []byte(fmt.Sprintln(tick)), os.ModePerm)
+ require.NoError(t, err)
+ case tick < 0:
+ err := os.Mkdir(filePath, os.ModePerm)
+ require.NoError(t, err)
+ }
+ tickC <- time.Now()
+ res := <-ch
+ if tick < 0 {
+ tick = 0
+ }
+ require.Equal(t, tick, res)
+ }
+ }()
+ err := tickFile(ctx, filePath, ch, tickC)
+ require.Equal(t, tt.errored, err != nil)
+ })
+ }
+}
diff --git a/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.service.golden b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.service.golden
new file mode 100644
index 0000000000000..185b4f07a1aa9
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.service.golden
@@ -0,0 +1,7 @@
+# teleport-update
+[Unit]
+Description=Teleport auto-update service
+
+[Service]
+Type=oneshot
+ExecStart=/usr/local/bin/teleport-update update
diff --git a/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.timer.golden b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.timer.golden
new file mode 100644
index 0000000000000..acca095d9825f
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.timer.golden
@@ -0,0 +1,11 @@
+# teleport-update
+[Unit]
+Description=Teleport auto-update timer unit
+
+[Timer]
+OnActiveSec=1m
+OnUnitActiveSec=5m
+RandomizedDelaySec=1m
+
+[Install]
+WantedBy=teleport.service
diff --git a/lib/autoupdate/agent/updater.go b/lib/autoupdate/agent/updater.go
index 9625481df2cd2..49dfe40fd27da 100644
--- a/lib/autoupdate/agent/updater.go
+++ b/lib/autoupdate/agent/updater.go
@@ -27,6 +27,7 @@ import (
"log/slog"
"net/http"
"os"
+ "os/exec"
"path/filepath"
"strings"
"time"
@@ -46,6 +47,10 @@ const (
DefaultLinkDir = "/usr/local"
// DefaultSystemDir is the location where packaged Teleport binaries and services are installed.
DefaultSystemDir = "/usr/local/teleport-system"
+ // VersionsDirName specifies the name of the subdirectory inside the Teleport data dir for storing Teleport versions.
+ VersionsDirName = "versions"
+ // BinaryName specifies the name of the updater binary.
+ BinaryName = "teleport-update"
)
const (
@@ -136,16 +141,20 @@ func NewLocalUpdater(cfg LocalUpdaterConfig) (*Updater, error) {
if cfg.SystemDir == "" {
cfg.SystemDir = DefaultSystemDir
}
- if cfg.VersionsDir == "" {
- cfg.VersionsDir = filepath.Join(libdefaults.DataDir, "versions")
+ if cfg.DataDir == "" {
+ cfg.DataDir = libdefaults.DataDir
+ }
+ installDir := filepath.Join(cfg.DataDir, VersionsDirName)
+ if err := os.MkdirAll(installDir, systemDirMode); err != nil {
+ return nil, trace.Errorf("failed to create install directory: %w", err)
}
return &Updater{
Log: cfg.Log,
Pool: certPool,
InsecureSkipVerify: cfg.InsecureSkipVerify,
- ConfigPath: filepath.Join(cfg.VersionsDir, updateConfigName),
+ ConfigPath: filepath.Join(installDir, updateConfigName),
Installer: &LocalInstaller{
- InstallDir: cfg.VersionsDir,
+ InstallDir: installDir,
LinkBinDir: filepath.Join(cfg.LinkDir, "bin"),
// For backwards-compatibility with symlinks created by package-based installs, we always
// link into /lib/systemd/system, even though, e.g., /usr/local/lib/systemd/system would work.
@@ -159,8 +168,24 @@ func NewLocalUpdater(cfg LocalUpdaterConfig) (*Updater, error) {
},
Process: &SystemdService{
ServiceName: "teleport.service",
+ PIDPath: "/run/teleport.pid",
Log: cfg.Log,
},
+ Setup: func(ctx context.Context) error {
+ name := filepath.Join(cfg.LinkDir, "bin", BinaryName)
+ if cfg.SelfSetup {
+ name = "/proc/self/exe"
+ }
+ cmd := exec.CommandContext(ctx, name,
+ "--data-dir", cfg.DataDir,
+ "--link-dir", cfg.LinkDir,
+ "setup")
+ cmd.Stderr = os.Stderr
+ cmd.Stdout = os.Stdout
+ cfg.Log.InfoContext(ctx, "Executing new teleport-update binary to update configuration.")
+ defer cfg.Log.InfoContext(ctx, "Finished executing new teleport-update binary.")
+ return trace.Wrap(cmd.Run())
+ },
}, nil
}
@@ -174,12 +199,14 @@ type LocalUpdaterConfig struct {
// DownloadTimeout is a timeout for file download requests.
// Defaults to no timeout.
DownloadTimeout time.Duration
- // VersionsDir for installing Teleport (usually /var/lib/teleport/versions).
- VersionsDir string
+ // DataDir for Teleport (usually /var/lib/teleport).
+ DataDir string
// LinkDir for installing Teleport (usually /usr/local).
LinkDir string
// SystemDir for package-installed Teleport installations (usually /usr/local/teleport-system).
SystemDir string
+ // SelfSetup mode for using the current version of the teleport-update to setup the update service.
+ SelfSetup bool
}
// Updater implements the agent-local logic for Teleport agent auto-updates.
@@ -196,6 +223,8 @@ type Updater struct {
Installer Installer
// Process manages a running instance of Teleport.
Process Process
+ // Setup installs the Teleport updater service using the linked installation.
+ Setup func(ctx context.Context) error
}
// Installer provides an API for installing Teleport agents.
@@ -336,6 +365,8 @@ func (u *Updater) Enable(ctx context.Context, override OverrideConfig) error {
return trace.Errorf("agent version not available from Teleport cluster")
}
+ u.Log.InfoContext(ctx, "Initiating initial update.", targetVersionKey, targetVersion, activeVersionKey, cfg.Status.ActiveVersion)
+
if err := u.update(ctx, cfg, targetVersion, flags); err != nil {
return trace.Wrap(err)
}
@@ -476,7 +507,7 @@ func (u *Updater) update(ctx context.Context, cfg *UpdateConfig, targetVersion s
}
}
- // Install the desired version (or validate existing installation)
+ // Install and link the desired version (or validate existing installation)
template := cfg.Spec.URLTemplate
if template == "" {
@@ -486,20 +517,38 @@ func (u *Updater) update(ctx context.Context, cfg *UpdateConfig, targetVersion s
if err != nil {
return trace.Errorf("failed to install: %w", err)
}
+
+ // TODO(slevine): if the target version has fewer binaries, this will
+ // leave old binaries linked. This may prevent the installation from
+ // being removed. To fix this, we should look for orphaned binaries
+ // and remove them.
+
revert, err := u.Installer.Link(ctx, targetVersion)
if err != nil {
return trace.Errorf("failed to link: %w", err)
}
+ // Verify that the linked installation contains a valid updater binary,
+ // and use it to update the updater's service files.
+
+ if err := u.Setup(ctx); err != nil {
+ if ok := revert(ctx); !ok {
+ u.Log.ErrorContext(ctx, "Failed to revert Teleport symlinks. Installation likely broken.")
+ }
+ return trace.Errorf("failed to setup updater: %w", err)
+ }
+
// If we fail to revert after this point, the next update/enable will
// fix the link to restore the active version.
// Sync process configuration after linking.
- if err := u.Process.Sync(ctx); err != nil {
- if errors.Is(err, context.Canceled) {
- return trace.Errorf("sync canceled")
- }
+ err = u.Process.Sync(ctx)
+ if errors.Is(err, ErrNotSupported) {
+ u.Log.WarnContext(ctx, "Not syncing systemd configuration because systemd is not running.")
+ } else if errors.Is(err, context.Canceled) {
+ return trace.Errorf("sync canceled")
+ } else if err != nil {
// If sync fails, we may have left the host in a bad state, so we revert linking and re-Sync.
u.Log.ErrorContext(ctx, "Reverting symlinks due to invalid configuration.")
if ok := revert(ctx); !ok {
@@ -517,10 +566,14 @@ func (u *Updater) update(ctx context.Context, cfg *UpdateConfig, targetVersion s
if cfg.Status.ActiveVersion != targetVersion {
u.Log.InfoContext(ctx, "Target version successfully installed.", targetVersionKey, targetVersion)
- if err := u.Process.Reload(ctx); err != nil && !errors.Is(err, ErrNotNeeded) {
- if errors.Is(err, context.Canceled) {
- return trace.Errorf("reload canceled")
- }
+ err := u.Process.Reload(ctx)
+ if errors.Is(err, context.Canceled) {
+ return trace.Errorf("reload canceled")
+ }
+ if err != nil &&
+ !errors.Is(err, ErrNotNeeded) && // no output if restart not needed
+ !errors.Is(err, ErrNotSupported) { // already logged above for Sync
+
// If reloading Teleport at the new version fails, revert, resync, and reload.
u.Log.ErrorContext(ctx, "Reverting symlinks due to failed restart.")
if ok := revert(ctx); !ok {
@@ -608,7 +661,11 @@ func validateConfigSpec(spec *UpdateSpec, override OverrideConfig) error {
if override.Group != "" {
spec.Group = override.Group
}
- if override.URLTemplate != "" {
+ switch override.URLTemplate {
+ case "":
+ case "default":
+ spec.URLTemplate = ""
+ default:
spec.URLTemplate = override.URLTemplate
}
if spec.URLTemplate != "" &&
diff --git a/lib/autoupdate/agent/updater_test.go b/lib/autoupdate/agent/updater_test.go
index 1197ac3d5a795..8f919762f53fb 100644
--- a/lib/autoupdate/agent/updater_test.go
+++ b/lib/autoupdate/agent/updater_test.go
@@ -83,7 +83,13 @@ func TestUpdater_Disable(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dir := t.TempDir()
- cfgPath := filepath.Join(dir, "update.yaml")
+ cfgPath := filepath.Join(dir, VersionsDirName, "update.yaml")
+
+ updater, err := NewLocalUpdater(LocalUpdaterConfig{
+ InsecureSkipVerify: true,
+ DataDir: dir,
+ })
+ require.NoError(t, err)
// Create config file only if provided in test case
if tt.cfg != nil {
@@ -92,11 +98,7 @@ func TestUpdater_Disable(t *testing.T) {
err = os.WriteFile(cfgPath, b, 0600)
require.NoError(t, err)
}
- updater, err := NewLocalUpdater(LocalUpdaterConfig{
- InsecureSkipVerify: true,
- VersionsDir: dir,
- })
- require.NoError(t, err)
+
err = updater.Disable(context.Background())
if tt.errMatch != "" {
require.Error(t, err)
@@ -142,6 +144,7 @@ func TestUpdater_Update(t *testing.T) {
syncCalls int
reloadCalls int
revertCalls int
+ setupCalls int
errMatch string
}{
{
@@ -166,6 +169,7 @@ func TestUpdater_Update(t *testing.T) {
requestGroup: "group",
syncCalls: 1,
reloadCalls: 1,
+ setupCalls: 1,
},
{
name: "updates disabled during window",
@@ -295,6 +299,7 @@ func TestUpdater_Update(t *testing.T) {
removedVersion: "backup-version",
syncCalls: 1,
reloadCalls: 1,
+ setupCalls: 1,
},
{
name: "backup version kept when no change",
@@ -338,6 +343,7 @@ func TestUpdater_Update(t *testing.T) {
removedVersion: "backup-version",
syncCalls: 1,
reloadCalls: 1,
+ setupCalls: 1,
},
{
name: "invalid metadata",
@@ -368,6 +374,7 @@ func TestUpdater_Update(t *testing.T) {
syncCalls: 2,
reloadCalls: 0,
revertCalls: 1,
+ setupCalls: 1,
errMatch: "sync error",
},
{
@@ -394,6 +401,7 @@ func TestUpdater_Update(t *testing.T) {
syncCalls: 2,
reloadCalls: 2,
revertCalls: 1,
+ setupCalls: 1,
errMatch: "reload error",
},
}
@@ -419,7 +427,13 @@ func TestUpdater_Update(t *testing.T) {
t.Cleanup(server.Close)
dir := t.TempDir()
- cfgPath := filepath.Join(dir, "update.yaml")
+ cfgPath := filepath.Join(dir, VersionsDirName, "update.yaml")
+
+ updater, err := NewLocalUpdater(LocalUpdaterConfig{
+ InsecureSkipVerify: true,
+ DataDir: dir,
+ })
+ require.NoError(t, err)
// Create config file only if provided in test case
if tt.cfg != nil {
@@ -430,12 +444,6 @@ func TestUpdater_Update(t *testing.T) {
require.NoError(t, err)
}
- updater, err := NewLocalUpdater(LocalUpdaterConfig{
- InsecureSkipVerify: true,
- VersionsDir: dir,
- })
- require.NoError(t, err)
-
var (
installedVersion string
installedTemplate string
@@ -481,6 +489,12 @@ func TestUpdater_Update(t *testing.T) {
},
}
+ var setupCalls int
+ updater.Setup = func(_ context.Context) error {
+ setupCalls++
+ return nil
+ }
+
ctx := context.Background()
err = updater.Update(ctx)
if tt.errMatch != "" {
@@ -498,6 +512,7 @@ func TestUpdater_Update(t *testing.T) {
require.Equal(t, tt.syncCalls, syncCalls)
require.Equal(t, tt.reloadCalls, reloadCalls)
require.Equal(t, tt.revertCalls, revertCalls)
+ require.Equal(t, tt.setupCalls, setupCalls)
if tt.cfg == nil {
_, err := os.Stat(cfgPath)
@@ -594,7 +609,13 @@ func TestUpdater_LinkPackage(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dir := t.TempDir()
- cfgPath := filepath.Join(dir, "update.yaml")
+ cfgPath := filepath.Join(dir, VersionsDirName, "update.yaml")
+
+ updater, err := NewLocalUpdater(LocalUpdaterConfig{
+ InsecureSkipVerify: true,
+ DataDir: dir,
+ })
+ require.NoError(t, err)
// Create config file only if provided in test case
if tt.cfg != nil {
@@ -604,12 +625,6 @@ func TestUpdater_LinkPackage(t *testing.T) {
require.NoError(t, err)
}
- updater, err := NewLocalUpdater(LocalUpdaterConfig{
- InsecureSkipVerify: true,
- VersionsDir: dir,
- })
- require.NoError(t, err)
-
var tryLinkSystemCalls int
updater.Installer = &testInstaller{
FuncTryLinkSystem: func(_ context.Context) error {
@@ -659,6 +674,7 @@ func TestUpdater_Enable(t *testing.T) {
syncCalls int
reloadCalls int
revertCalls int
+ setupCalls int
errMatch string
}{
{
@@ -681,6 +697,7 @@ func TestUpdater_Enable(t *testing.T) {
requestGroup: "group",
syncCalls: 1,
reloadCalls: 1,
+ setupCalls: 1,
},
{
name: "config from user",
@@ -706,6 +723,7 @@ func TestUpdater_Enable(t *testing.T) {
linkedVersion: "new-version",
syncCalls: 1,
reloadCalls: 1,
+ setupCalls: 1,
},
{
name: "already enabled",
@@ -725,6 +743,7 @@ func TestUpdater_Enable(t *testing.T) {
linkedVersion: "16.3.0",
syncCalls: 1,
reloadCalls: 1,
+ setupCalls: 1,
},
{
name: "insecure URL",
@@ -766,6 +785,7 @@ func TestUpdater_Enable(t *testing.T) {
linkedVersion: "16.3.0",
syncCalls: 1,
reloadCalls: 0,
+ setupCalls: 1,
},
{
name: "backup version removed on install",
@@ -784,6 +804,7 @@ func TestUpdater_Enable(t *testing.T) {
removedVersion: "backup-version",
syncCalls: 1,
reloadCalls: 1,
+ setupCalls: 1,
},
{
name: "backup version kept for validation",
@@ -802,6 +823,7 @@ func TestUpdater_Enable(t *testing.T) {
removedVersion: "",
syncCalls: 1,
reloadCalls: 0,
+ setupCalls: 1,
},
{
name: "config does not exist",
@@ -811,6 +833,7 @@ func TestUpdater_Enable(t *testing.T) {
linkedVersion: "16.3.0",
syncCalls: 1,
reloadCalls: 1,
+ setupCalls: 1,
},
{
name: "FIPS and Enterprise flags",
@@ -820,6 +843,7 @@ func TestUpdater_Enable(t *testing.T) {
linkedVersion: "16.3.0",
syncCalls: 1,
reloadCalls: 1,
+ setupCalls: 1,
},
{
name: "invalid metadata",
@@ -836,6 +860,7 @@ func TestUpdater_Enable(t *testing.T) {
syncCalls: 2,
reloadCalls: 0,
revertCalls: 1,
+ setupCalls: 1,
errMatch: "sync error",
},
{
@@ -848,6 +873,7 @@ func TestUpdater_Enable(t *testing.T) {
syncCalls: 2,
reloadCalls: 2,
revertCalls: 1,
+ setupCalls: 1,
errMatch: "reload error",
},
}
@@ -855,7 +881,13 @@ func TestUpdater_Enable(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dir := t.TempDir()
- cfgPath := filepath.Join(dir, "update.yaml")
+ cfgPath := filepath.Join(dir, VersionsDirName, "update.yaml")
+
+ updater, err := NewLocalUpdater(LocalUpdaterConfig{
+ InsecureSkipVerify: true,
+ DataDir: dir,
+ })
+ require.NoError(t, err)
// Create config file only if provided in test case
if tt.cfg != nil {
@@ -886,12 +918,6 @@ func TestUpdater_Enable(t *testing.T) {
tt.userCfg.Proxy = strings.TrimPrefix(server.URL, "https://")
}
- updater, err := NewLocalUpdater(LocalUpdaterConfig{
- InsecureSkipVerify: true,
- VersionsDir: dir,
- })
- require.NoError(t, err)
-
var (
installedVersion string
installedTemplate string
@@ -936,6 +962,11 @@ func TestUpdater_Enable(t *testing.T) {
return tt.reloadErr
},
}
+ var setupCalls int
+ updater.Setup = func(_ context.Context) error {
+ setupCalls++
+ return nil
+ }
ctx := context.Background()
err = updater.Enable(ctx, tt.userCfg)
@@ -954,6 +985,7 @@ func TestUpdater_Enable(t *testing.T) {
require.Equal(t, tt.syncCalls, syncCalls)
require.Equal(t, tt.reloadCalls, reloadCalls)
require.Equal(t, tt.revertCalls, revertCalls)
+ require.Equal(t, tt.setupCalls, setupCalls)
if tt.cfg == nil && err != nil {
_, err := os.Stat(cfgPath)
diff --git a/tool/teleport-update/main.go b/tool/teleport-update/main.go
index d559ad3e75cdd..c2d8f07165d25 100644
--- a/tool/teleport-update/main.go
+++ b/tool/teleport-update/main.go
@@ -21,6 +21,7 @@ package main
import (
"context"
"errors"
+ "fmt"
"log/slog"
"os"
"os/signal"
@@ -58,10 +59,8 @@ const (
)
const (
- // versionsDirName specifies the name of the subdirectory inside of the Teleport data dir for storing Teleport versions.
- versionsDirName = "versions"
- // lockFileName specifies the name of the file inside versionsDirName containing the flock lock preventing concurrent updater execution.
- lockFileName = ".lock"
+ // lockFileName specifies the name of the file containing the flock lock preventing concurrent updater execution.
+ lockFileName = ".update-lock"
)
var plog = logutils.NewPackageLogger(teleport.ComponentKey, teleport.ComponentUpdater)
@@ -84,6 +83,8 @@ type cliConfig struct {
DataDir string
// LinkDir for linking binaries and systemd services
LinkDir string
+ // SelfSetup mode for using the current version of the teleport-update to setup the update service.
+ SelfSetup bool
}
func Run(args []string) error {
@@ -91,7 +92,7 @@ func Run(args []string) error {
ctx := context.Background()
ctx, _ = signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM)
- app := libutils.InitCLIParser("teleport-update", appHelp).Interspersed(false)
+ app := libutils.InitCLIParser(autoupdate.BinaryName, appHelp).Interspersed(false)
app.Flag("debug", "Verbose logging to stdout.").
Short('d').BoolVar(&ccfg.Debug)
app.Flag("data-dir", "Teleport data directory. Access to this directory should be limited.").
@@ -103,7 +104,7 @@ func Run(args []string) error {
app.HelpFlag.Short('h')
- versionCmd := app.Command("version", "Print the version of your teleport-updater binary.")
+ versionCmd := app.Command("version", fmt.Sprintf("Print the version of your %s binary.", autoupdate.BinaryName))
enableCmd := app.Command("enable", "Enable agent auto-updates and perform initial update.")
enableCmd.Flag("proxy", "Address of the Teleport Proxy.").
@@ -114,14 +115,21 @@ func Run(args []string) error {
Short('t').Envar(templateEnvVar).StringVar(&ccfg.URLTemplate)
enableCmd.Flag("force-version", "Force the provided version instead of querying it from the Teleport cluster.").
Short('f').Envar(updateVersionEnvVar).Hidden().StringVar(&ccfg.ForceVersion)
+ enableCmd.Flag("self-setup", "Use the current teleport-update binary to create systemd service config for auto-updates.").
+ Short('s').Hidden().BoolVar(&ccfg.SelfSetup)
// TODO(sclevine): add force-fips and force-enterprise as hidden flags
disableCmd := app.Command("disable", "Disable agent auto-updates.")
updateCmd := app.Command("update", "Update agent to the latest version, if a new version is available.")
+ updateCmd.Flag("self-setup", "Use the current teleport-update binary to create systemd service config for auto-updates.").
+ Short('s').Hidden().BoolVar(&ccfg.SelfSetup)
linkCmd := app.Command("link", "Link the system installation of Teleport from the Teleport package, if auto-updates is disabled.")
+ setupCmd := app.Command("setup", "Write configuration files that run the update subcommand on a timer.").
+ Hidden()
+
libutils.UpdateAppUsageTemplate(app, args)
command, err := app.Parse(args)
if err != nil {
@@ -143,6 +151,8 @@ func Run(args []string) error {
err = cmdUpdate(ctx, &ccfg)
case linkCmd.FullCommand():
err = cmdLink(ctx, &ccfg)
+ case setupCmd.FullCommand():
+ err = cmdSetup(ctx, &ccfg)
case versionCmd.FullCommand():
modules.GetModules().PrintVersion()
default:
@@ -172,12 +182,17 @@ func setupLogger(debug bool, format string) error {
// cmdDisable disables updates.
func cmdDisable(ctx context.Context, ccfg *cliConfig) error {
- versionsDir := filepath.Join(ccfg.DataDir, versionsDirName)
- if err := os.MkdirAll(versionsDir, 0755); err != nil {
- return trace.Errorf("failed to create versions directory: %w", err)
+ updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{
+ DataDir: ccfg.DataDir,
+ LinkDir: ccfg.LinkDir,
+ SystemDir: autoupdate.DefaultSystemDir,
+ SelfSetup: ccfg.SelfSetup,
+ Log: plog,
+ })
+ if err != nil {
+ return trace.Errorf("failed to setup updater: %w", err)
}
-
- unlock, err := libutils.FSWriteLock(filepath.Join(versionsDir, lockFileName))
+ unlock, err := libutils.FSWriteLock(filepath.Join(ccfg.DataDir, lockFileName))
if err != nil {
return trace.Errorf("failed to grab concurrent execution lock: %w", err)
}
@@ -186,15 +201,6 @@ func cmdDisable(ctx context.Context, ccfg *cliConfig) error {
plog.DebugContext(ctx, "Failed to close lock file", "error", err)
}
}()
- updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{
- VersionsDir: versionsDir,
- LinkDir: ccfg.LinkDir,
- SystemDir: autoupdate.DefaultSystemDir,
- Log: plog,
- })
- if err != nil {
- return trace.Errorf("failed to setup updater: %w", err)
- }
if err := updater.Disable(ctx); err != nil {
return trace.Wrap(err)
}
@@ -203,13 +209,19 @@ func cmdDisable(ctx context.Context, ccfg *cliConfig) error {
// cmdEnable enables updates and triggers an initial update.
func cmdEnable(ctx context.Context, ccfg *cliConfig) error {
- versionsDir := filepath.Join(ccfg.DataDir, versionsDirName)
- if err := os.MkdirAll(versionsDir, 0755); err != nil {
- return trace.Errorf("failed to create versions directory: %w", err)
+ updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{
+ DataDir: ccfg.DataDir,
+ LinkDir: ccfg.LinkDir,
+ SystemDir: autoupdate.DefaultSystemDir,
+ SelfSetup: ccfg.SelfSetup,
+ Log: plog,
+ })
+ if err != nil {
+ return trace.Errorf("failed to setup updater: %w", err)
}
// Ensure enable can't run concurrently.
- unlock, err := libutils.FSWriteLock(filepath.Join(versionsDir, lockFileName))
+ unlock, err := libutils.FSWriteLock(filepath.Join(ccfg.DataDir, lockFileName))
if err != nil {
return trace.Errorf("failed to grab concurrent execution lock: %w", err)
}
@@ -218,16 +230,6 @@ func cmdEnable(ctx context.Context, ccfg *cliConfig) error {
plog.DebugContext(ctx, "Failed to close lock file", "error", err)
}
}()
-
- updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{
- VersionsDir: versionsDir,
- LinkDir: ccfg.LinkDir,
- SystemDir: autoupdate.DefaultSystemDir,
- Log: plog,
- })
- if err != nil {
- return trace.Errorf("failed to setup updater: %w", err)
- }
if err := updater.Enable(ctx, ccfg.OverrideConfig); err != nil {
return trace.Wrap(err)
}
@@ -236,13 +238,18 @@ func cmdEnable(ctx context.Context, ccfg *cliConfig) error {
// cmdUpdate updates Teleport to the version specified by cluster reachable at the proxy address.
func cmdUpdate(ctx context.Context, ccfg *cliConfig) error {
- versionsDir := filepath.Join(ccfg.DataDir, versionsDirName)
- if err := os.MkdirAll(versionsDir, 0755); err != nil {
- return trace.Errorf("failed to create versions directory: %w", err)
+ updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{
+ DataDir: ccfg.DataDir,
+ LinkDir: ccfg.LinkDir,
+ SystemDir: autoupdate.DefaultSystemDir,
+ SelfSetup: ccfg.SelfSetup,
+ Log: plog,
+ })
+ if err != nil {
+ return trace.Errorf("failed to setup updater: %w", err)
}
-
// Ensure update can't run concurrently.
- unlock, err := libutils.FSWriteLock(filepath.Join(versionsDir, lockFileName))
+ unlock, err := libutils.FSWriteLock(filepath.Join(ccfg.DataDir, lockFileName))
if err != nil {
return trace.Errorf("failed to grab concurrent execution lock: %w", err)
}
@@ -252,15 +259,6 @@ func cmdUpdate(ctx context.Context, ccfg *cliConfig) error {
}
}()
- updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{
- VersionsDir: versionsDir,
- LinkDir: ccfg.LinkDir,
- SystemDir: autoupdate.DefaultSystemDir,
- Log: plog,
- })
- if err != nil {
- return trace.Errorf("failed to setup updater: %w", err)
- }
if err := updater.Update(ctx); err != nil {
return trace.Wrap(err)
}
@@ -269,10 +267,19 @@ func cmdUpdate(ctx context.Context, ccfg *cliConfig) error {
// cmdLink creates system package links if no version is linked and auto-updates is disabled.
func cmdLink(ctx context.Context, ccfg *cliConfig) error {
- versionsDir := filepath.Join(ccfg.DataDir, versionsDirName)
+ updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{
+ DataDir: ccfg.DataDir,
+ LinkDir: ccfg.LinkDir,
+ SystemDir: autoupdate.DefaultSystemDir,
+ SelfSetup: ccfg.SelfSetup,
+ Log: plog,
+ })
+ if err != nil {
+ return trace.Errorf("failed to setup updater: %w", err)
+ }
// Skip operation and warn if the updater is currently running.
- unlock, err := libutils.FSTryReadLock(filepath.Join(versionsDir, lockFileName))
+ unlock, err := libutils.FSTryReadLock(filepath.Join(ccfg.DataDir, lockFileName))
if errors.Is(err, libutils.ErrUnsuccessfulLockTry) {
plog.WarnContext(ctx, "Updater is currently running. Skipping package linking.")
return nil
@@ -285,17 +292,18 @@ func cmdLink(ctx context.Context, ccfg *cliConfig) error {
plog.DebugContext(ctx, "Failed to close lock file", "error", err)
}
}()
- updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{
- VersionsDir: versionsDir,
- LinkDir: ccfg.LinkDir,
- SystemDir: autoupdate.DefaultSystemDir,
- Log: plog,
- })
- if err != nil {
- return trace.Errorf("failed to setup updater: %w", err)
- }
+
if err := updater.LinkPackage(ctx); err != nil {
return trace.Wrap(err)
}
return nil
}
+
+// cmdSetup writes configuration files that are needed to run teleport-update update.
+func cmdSetup(ctx context.Context, ccfg *cliConfig) error {
+ err := autoupdate.Setup(ctx, plog, ccfg.LinkDir, ccfg.DataDir)
+ if err != nil {
+ return trace.Errorf("failed to setup teleport-update service: %w", err)
+ }
+ return nil
+}