diff --git a/lib/autoupdate/agent/config.go b/lib/autoupdate/agent/config.go new file mode 100644 index 0000000000000..6664fcaa94131 --- /dev/null +++ b/lib/autoupdate/agent/config.go @@ -0,0 +1,122 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package agent + +import ( + "context" + "errors" + "log/slog" + "os" + "path/filepath" + "text/template" + + "github.com/google/renameio/v2" + "github.com/gravitational/trace" +) + +const ( + updateServiceTemplate = `# teleport-update +[Unit] +Description=Teleport auto-update service + +[Service] +Type=oneshot +ExecStart={{.LinkDir}}/bin/teleport-update update +` + updateTimerTemplate = `# teleport-update +[Unit] +Description=Teleport auto-update timer unit + +[Timer] +OnActiveSec=1m +OnUnitActiveSec=5m +RandomizedDelaySec=1m + +[Install] +WantedBy=teleport.service +` +) + +func Setup(ctx context.Context, log *slog.Logger, linkDir, dataDir string) error { + err := writeConfigFiles(linkDir, dataDir) + if err != nil { + return trace.Errorf("failed to write teleport-update systemd config files: %w", err) + } + svc := &SystemdService{ + ServiceName: "teleport-update.timer", + Log: log, + } + err = svc.Sync(ctx) + if errors.Is(err, ErrNotSupported) { + log.WarnContext(ctx, "Not enabling systemd service because systemd is not running.") + return nil + } + if err != 
nil { + return trace.Errorf("failed to sync systemd config: %w", err) + } + if err := svc.Enable(ctx, true); err != nil { + return trace.Errorf("failed to enable teleport-update systemd timer: %w", err) + } + return nil +} + +func writeConfigFiles(linkDir, dataDir string) error { + // TODO(sclevine): revert on failure + + servicePath := filepath.Join(linkDir, serviceDir, updateServiceName) + err := writeTemplate(servicePath, updateServiceTemplate, linkDir, dataDir) + if err != nil { + return trace.Wrap(err) + } + timerPath := filepath.Join(linkDir, serviceDir, updateTimerName) + err = writeTemplate(timerPath, updateTimerTemplate, linkDir, dataDir) + if err != nil { + return trace.Wrap(err) + } + + return nil +} + +func writeTemplate(path, t, linkDir, dataDir string) error { + if err := os.MkdirAll(filepath.Dir(path), systemDirMode); err != nil { + return trace.Wrap(err) + } + opts := []renameio.Option{ + renameio.WithPermissions(configFileMode), + renameio.WithExistingPermissions(), + } + f, err := renameio.NewPendingFile(path, opts...) + if err != nil { + return trace.Wrap(err) + } + defer f.Cleanup() + + tmpl, err := template.New(filepath.Base(path)).Parse(t) + if err != nil { + return trace.Wrap(err) + } + err = tmpl.Execute(f, struct { + LinkDir string + DataDir string + }{linkDir, dataDir}) + if err != nil { + return trace.Wrap(err) + } + return trace.Wrap(f.CloseAtomicallyReplace()) +} diff --git a/lib/autoupdate/agent/config_test.go b/lib/autoupdate/agent/config_test.go new file mode 100644 index 0000000000000..16cbdb5374fb6 --- /dev/null +++ b/lib/autoupdate/agent/config_test.go @@ -0,0 +1,65 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. 
// replaceValues returns data with every occurrence of each map value replaced
// by that value's key. The test uses it to turn per-run temporary paths back
// into their well-known defaults before comparing against golden files.
// Entries are applied in map iteration order.
func replaceValues(data []byte, m map[string]string) []byte {
	out := data
	for placeholder, actual := range m {
		out = bytes.Replace(out, []byte(actual), []byte(placeholder), -1)
	}
	return out
}
const (
	// crashMonitorInterval is the polling interval at which the service's PID file (PIDPath)
	// is re-read while monitoring for crashes after a reload or restart.
	crashMonitorInterval = 2 * time.Second
	// minRunningIntervalsBeforeStable is the number of consecutive intervals with the same running PID detected
	// before the service is determined stable.
	minRunningIntervalsBeforeStable = 6
	// maxCrashesBeforeFailure is the number of crashes tolerated while monitoring.
	// Detecting more crashes than this marks the service as crash-looping.
	maxCrashesBeforeFailure = 2
	// crashMonitorTimeout is the maximum total time spent monitoring the service for crashes.
	// If the PID has not been observed as stable by then, monitoring fails with the context's error.
	crashMonitorTimeout = 30 * time.Second
)
Skipping crash monitoring.", unitKey, s.ServiceName, errorKey, err) + } + // Attempt graceful reload of running service. code = s.systemctl(ctx, slog.LevelError, "reload", s.ServiceName) switch { case code < 0: - return trace.Errorf("unable to attempt reload of Teleport systemd service") + return trace.Errorf("unable to reload systemd service") case code > 0: // Graceful reload fails, try hard restart. code = s.systemctl(ctx, slog.LevelError, "try-restart", s.ServiceName) if code != 0 { - return trace.Errorf("hard restart of Teleport systemd service failed") + return trace.Errorf("hard restart of systemd service failed") } - s.Log.WarnContext(ctx, "Teleport ungracefully restarted. Connections potentially dropped.") + s.Log.WarnContext(ctx, "Service ungracefully restarted. Connections potentially dropped.", unitKey, s.ServiceName) default: - s.Log.InfoContext(ctx, "Teleport gracefully reloaded.") + s.Log.InfoContext(ctx, "Gracefully reloaded.", unitKey, s.ServiceName) + } + if initPID != 0 { + s.Log.InfoContext(ctx, "Monitoring PID file to detect crashes.", unitKey, s.ServiceName) + return trace.Wrap(s.monitor(ctx, initPID)) } + return nil +} - // TODO(sclevine): Ensure restart was successful and verify healthcheck. +// monitor for the started process to ensure it's running by polling PIDFile. +// This function detects several types of crashes while minimizing its own runtime during updates. +// For example, the process may crash by failing to fork (non-running PID), or looping (repeatedly changing PID), +// or getting stuck on quit (no change in PID). +// initPID is the PID before the restart operation has been issued. 
+func (s SystemdService) monitor(ctx context.Context, initPID int) error { + ctx, cancel := context.WithTimeout(ctx, crashMonitorTimeout) + defer cancel() + tickC := time.NewTicker(crashMonitorInterval).C + pidC := make(chan int) + g := &errgroup.Group{} + g.Go(func() error { + return tickFile(ctx, s.PIDPath, pidC, tickC) + }) + err := s.waitForStablePID(ctx, minRunningIntervalsBeforeStable, maxCrashesBeforeFailure, + initPID, pidC, func(pid int) error { + p, err := os.FindProcess(pid) + if err != nil { + return trace.Wrap(err) + } + return trace.Wrap(p.Signal(syscall.Signal(0))) + }) + cancel() + if err := g.Wait(); err != nil { + s.Log.ErrorContext(ctx, "Error monitoring for crashing process.", errorKey, err, unitKey, s.ServiceName) + } + return trace.Wrap(err) +} + +// monitorRestarts receives restart times on timeCh. +// Each restart time that differs from the preceding restart time counts as a restart. +// If maxRestarts is exceeded, monitorRestarts returns an error. +// Each restart time that matches the proceeding restart time counts as a clean reading. +// If minClean is reached before maxRestarts is exceeded, monitorRestarts runs nil. +func (s SystemdService) waitForStablePID(ctx context.Context, minStable, maxCrashes, baselinePID int, pidC <-chan int, findPID func(pid int) error) error { + pid := baselinePID + var last, stale int + var crashes int + for stable := 0; stable < minStable; stable++ { + select { + case <-ctx.Done(): + return ctx.Err() + case p := <-pidC: + last = pid + pid = p + } + // A "crash" is defined as a transition away from a new (non-baseline) PID, or + // an interval where the current PID remains non-running since the last check. 
+ if (last != 0 && pid != last && last != baselinePID) || + (stale != 0 && pid == stale && last == stale) { + crashes++ + } + if crashes > maxCrashes { + return trace.Errorf("detected crashing process") + } + + // PID can only be stable if it is a real PID that is not new, has changed at least once, + // and hasn't been observed as missing. + if pid == 0 || + pid == baselinePID || + pid == stale || + pid != last { + stable = -1 + continue + } + err := findPID(pid) + // A stale PID most likely indicates that the process forked and crashed without systemd noticing. + // There is a small chance that we read the PID file before systemd removed it. + // Note: we only perform this check on PIDs that survive one iteration. + if errors.Is(err, os.ErrProcessDone) || + errors.Is(err, syscall.ESRCH) { + if pid != stale && + pid != baselinePID { + stale = pid + s.Log.WarnContext(ctx, "Detected stale PID.", unitKey, s.ServiceName, "pid", stale) + } + stable = -1 + continue + } + if err != nil { + return trace.Wrap(err) + } + } return nil } +func readInt(path string) (int, error) { + p, err := readFileN(path, 32) + if err != nil { + return 0, trace.Wrap(err) + } + i, err := strconv.ParseInt(string(bytes.TrimSpace(p)), 10, 64) + if err != nil { + return 0, trace.Wrap(err) + } + return int(i), nil +} + +// tickFile reads the current time on tickC, and outputs the last read int from path on ch for each received tick. +// If the path cannot be read, tickFile sends 0 on ch. +// Any error from the last attempt to read path is returned when ctx is canceled, unless the error is os.ErrNotExist. 
+func tickFile(ctx context.Context, path string, ch chan<- int, tickC <-chan time.Time) error { + var err error + for { + // two select statements -> never skip reads + select { + case <-tickC: + case <-ctx.Done(): + return err + } + var t int + t, err = readInt(path) + if errors.Is(err, os.ErrNotExist) { + err = nil + } + select { + case ch <- t: + case <-ctx.Done(): + return err + } + } +} + // Sync systemd service configuration by running systemctl daemon-reload. // See Process interface for more details. func (s SystemdService) Sync(ctx context.Context) error { @@ -89,6 +250,24 @@ func (s SystemdService) Sync(ctx context.Context) error { if code != 0 { return trace.Errorf("unable to reload systemd configuration") } + s.Log.InfoContext(ctx, "Systemd configuration synced.") + return nil +} + +// Enable the systemd service. +func (s SystemdService) Enable(ctx context.Context, now bool) error { + if err := s.checkSystem(ctx); err != nil { + return trace.Wrap(err) + } + args := []string{"enable", s.ServiceName} + if now { + args = append(args, "--now") + } + code := s.systemctl(ctx, slog.LevelError, args...) + if code != 0 { + return trace.Errorf("unable to enable systemd service") + } + s.Log.InfoContext(ctx, "Service enabled.", unitKey, s.ServiceName) return nil } @@ -106,9 +285,42 @@ func (s SystemdService) checkSystem(ctx context.Context) error { // Output sent to stdout is logged at debug level. // Output sent to stderr is logged at the level specified by errLevel. func (s SystemdService) systemctl(ctx context.Context, errLevel slog.Level, args ...string) int { - cmd := exec.CommandContext(ctx, "systemctl", args...) - stderr := &lineLogger{ctx: ctx, log: s.Log, level: errLevel} - stdout := &lineLogger{ctx: ctx, log: s.Log, level: slog.LevelDebug} + cmd := &localExec{ + Log: s.Log, + ErrLevel: errLevel, + OutLevel: slog.LevelDebug, + } + code, err := cmd.Run(ctx, "systemctl", args...) 
// localExec runs a command locally, logging any output.
// Lines written to the command's stderr are logged at ErrLevel and lines
// written to its stdout are logged at OutLevel.
type localExec struct {
	// Log contains a slog logger.
	// Defaults to slog.Default() if nil.
	// NOTE(review): the visible part of Run passes Log to the line loggers
	// without a nil check; confirm that a default is applied before relying on nil.
	Log *slog.Logger
	// ErrLevel is the log level for stderr.
	ErrLevel slog.Level
	// OutLevel is the log level for stdout.
	OutLevel slog.Level
}
type lineLogger struct { - ctx context.Context - log *slog.Logger - level slog.Level + ctx context.Context + log *slog.Logger + level slog.Level + prefix string last bytes.Buffer } +func (w *lineLogger) out(s string) { + w.log.Log(w.ctx, w.level, w.prefix+s) //nolint:sloglint // msg cannot be constant +} + func (w *lineLogger) Write(p []byte) (n int, err error) { lines := bytes.Split(p, []byte("\n")) // Finish writing line @@ -153,13 +364,13 @@ func (w *lineLogger) Write(p []byte) (n int, err error) { } // Newline found, log line - w.log.Log(w.ctx, w.level, w.last.String()) //nolint:sloglint // msg cannot be constant + w.out(w.last.String()) n += 1 w.last.Reset() // Log lines that are already newline-terminated for _, line := range lines[:len(lines)-1] { - w.log.Log(w.ctx, w.level, string(line)) //nolint:sloglint // msg cannot be constant + w.out(string(line)) n += len(line) + 1 } @@ -174,6 +385,6 @@ func (w *lineLogger) Flush() { if w.last.Len() == 0 { return } - w.log.Log(w.ctx, w.level, w.last.String()) //nolint:sloglint // msg cannot be constant + w.out(w.last.String()) w.last.Reset() } diff --git a/lib/autoupdate/agent/process_test.go b/lib/autoupdate/agent/process_test.go index 5ffa70dd0091e..c558a7539831a 100644 --- a/lib/autoupdate/agent/process_test.go +++ b/lib/autoupdate/agent/process_test.go @@ -21,8 +21,13 @@ package agent import ( "bytes" "context" + "errors" + "fmt" "log/slog" + "os" + "path/filepath" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -69,3 +74,266 @@ func msgOnly(_ []string, a slog.Attr) slog.Attr { } return slog.Attr{Key: a.Key, Value: a.Value} } + +func TestWaitForStablePID(t *testing.T) { + t.Parallel() + + svc := &SystemdService{ + Log: slog.Default(), + } + + for _, tt := range []struct { + name string + ticks []int + baseline int + minStable int + maxCrashes int + findErrs map[int]error + + errored bool + canceled bool + }{ + { + name: "immediate restart", + ticks: []int{2, 2}, + baseline: 1, + minStable: 1, + 
maxCrashes: 1, + }, + { + name: "zero stable", + }, + { + name: "immediate crash", + ticks: []int{2, 3}, + baseline: 1, + minStable: 1, + maxCrashes: 0, + errored: true, + }, + { + name: "no changes times out", + ticks: []int{1, 1, 1, 1}, + baseline: 1, + minStable: 3, + maxCrashes: 2, + canceled: true, + }, + { + name: "baseline restart", + ticks: []int{2, 2, 2, 2}, + baseline: 1, + minStable: 3, + maxCrashes: 2, + }, + { + name: "one restart then stable", + ticks: []int{1, 2, 2, 2, 2}, + baseline: 1, + minStable: 3, + maxCrashes: 2, + }, + { + name: "two restarts then stable", + ticks: []int{1, 2, 3, 3, 3, 3}, + baseline: 1, + minStable: 3, + maxCrashes: 2, + }, + { + name: "three restarts then stable", + ticks: []int{1, 2, 3, 4, 4, 4, 4}, + baseline: 1, + minStable: 3, + maxCrashes: 2, + }, + { + name: "too many restarts excluding baseline", + ticks: []int{1, 2, 3, 4, 5}, + baseline: 1, + minStable: 3, + maxCrashes: 2, + errored: true, + }, + { + name: "too many restarts including baseline", + ticks: []int{1, 2, 3, 4}, + baseline: 0, + minStable: 3, + maxCrashes: 2, + errored: true, + }, + { + name: "too many restarts slow", + ticks: []int{1, 1, 1, 2, 2, 2, 3, 3, 3, 4}, + baseline: 0, + minStable: 3, + maxCrashes: 2, + errored: true, + }, + { + name: "too many restarts after stable", + ticks: []int{1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4}, + baseline: 0, + minStable: 3, + maxCrashes: 2, + }, + { + name: "stable after too many restarts", + ticks: []int{1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4}, + baseline: 0, + minStable: 3, + maxCrashes: 2, + errored: true, + }, + { + name: "cancel", + ticks: []int{1, 1, 1}, + baseline: 0, + minStable: 3, + maxCrashes: 2, + canceled: true, + }, + { + name: "stale PID crash", + ticks: []int{2, 2, 2, 2, 2}, + baseline: 1, + minStable: 3, + maxCrashes: 2, + findErrs: map[int]error{ + 2: os.ErrProcessDone, + }, + errored: true, + }, + { + name: "stale PID but fixed", + ticks: []int{2, 2, 3, 3, 3, 3}, + baseline: 1, + minStable: 3, + 
maxCrashes: 2, + findErrs: map[int]error{ + 2: os.ErrProcessDone, + }, + }, + { + name: "error PID", + ticks: []int{2, 2, 3, 3, 3, 3}, + baseline: 1, + minStable: 3, + maxCrashes: 2, + findErrs: map[int]error{ + 2: errors.New("bad"), + }, + errored: true, + }, + } { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + ch := make(chan int) + go func() { + defer cancel() // always quit after last tick + for _, tick := range tt.ticks { + ch <- tick + } + }() + err := svc.waitForStablePID(ctx, tt.minStable, tt.maxCrashes, + tt.baseline, ch, func(pid int) error { + return tt.findErrs[pid] + }) + require.Equal(t, tt.canceled, errors.Is(err, context.Canceled)) + if !tt.canceled { + require.Equal(t, tt.errored, err != nil) + } + }) + } +} + +func TestTickFile(t *testing.T) { + t.Parallel() + + for _, tt := range []struct { + name string + ticks []int + errored bool + }{ + { + name: "consistent", + ticks: []int{1, 1, 1}, + errored: false, + }, + { + name: "divergent", + ticks: []int{1, 2, 3}, + errored: false, + }, + { + name: "start error", + ticks: []int{-1, 1, 1}, + errored: false, + }, + { + name: "ephemeral error", + ticks: []int{1, -1, 1}, + errored: false, + }, + { + name: "end error", + ticks: []int{1, 1, -1}, + errored: true, + }, + { + name: "start missing", + ticks: []int{0, 1, 1}, + errored: false, + }, + { + name: "ephemeral missing", + ticks: []int{1, 0, 1}, + errored: false, + }, + { + name: "end missing", + ticks: []int{1, 1, 0}, + errored: false, + }, + { + name: "cancel-only", + errored: false, + }, + } { + t.Run(tt.name, func(t *testing.T) { + filePath := filepath.Join(t.TempDir(), "file") + + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + tickC := make(chan time.Time) + ch := make(chan int) + + go func() { + defer cancel() // always quit after last tick or fail + for _, tick := range tt.ticks { + _ = os.RemoveAll(filePath) + switch { + case 
tick > 0: + err := os.WriteFile(filePath, []byte(fmt.Sprintln(tick)), os.ModePerm) + require.NoError(t, err) + case tick < 0: + err := os.Mkdir(filePath, os.ModePerm) + require.NoError(t, err) + } + tickC <- time.Now() + res := <-ch + if tick < 0 { + tick = 0 + } + require.Equal(t, tick, res) + } + }() + err := tickFile(ctx, filePath, ch, tickC) + require.Equal(t, tt.errored, err != nil) + }) + } +} diff --git a/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.service.golden b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.service.golden new file mode 100644 index 0000000000000..185b4f07a1aa9 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.service.golden @@ -0,0 +1,7 @@ +# teleport-update +[Unit] +Description=Teleport auto-update service + +[Service] +Type=oneshot +ExecStart=/usr/local/bin/teleport-update update diff --git a/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.timer.golden b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.timer.golden new file mode 100644 index 0000000000000..acca095d9825f --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/teleport-update.timer.golden @@ -0,0 +1,11 @@ +# teleport-update +[Unit] +Description=Teleport auto-update timer unit + +[Timer] +OnActiveSec=1m +OnUnitActiveSec=5m +RandomizedDelaySec=1m + +[Install] +WantedBy=teleport.service diff --git a/lib/autoupdate/agent/updater.go b/lib/autoupdate/agent/updater.go index 9625481df2cd2..49dfe40fd27da 100644 --- a/lib/autoupdate/agent/updater.go +++ b/lib/autoupdate/agent/updater.go @@ -27,6 +27,7 @@ import ( "log/slog" "net/http" "os" + "os/exec" "path/filepath" "strings" "time" @@ -46,6 +47,10 @@ const ( DefaultLinkDir = "/usr/local" // DefaultSystemDir is the location where packaged Teleport binaries and services are installed. 
DefaultSystemDir = "/usr/local/teleport-system" + // VersionsDirName specifies the name of the subdirectory inside the Teleport data dir for storing Teleport versions. + VersionsDirName = "versions" + // BinaryName specifies the name of the updater binary. + BinaryName = "teleport-update" ) const ( @@ -136,16 +141,20 @@ func NewLocalUpdater(cfg LocalUpdaterConfig) (*Updater, error) { if cfg.SystemDir == "" { cfg.SystemDir = DefaultSystemDir } - if cfg.VersionsDir == "" { - cfg.VersionsDir = filepath.Join(libdefaults.DataDir, "versions") + if cfg.DataDir == "" { + cfg.DataDir = libdefaults.DataDir + } + installDir := filepath.Join(cfg.DataDir, VersionsDirName) + if err := os.MkdirAll(installDir, systemDirMode); err != nil { + return nil, trace.Errorf("failed to create install directory: %w", err) } return &Updater{ Log: cfg.Log, Pool: certPool, InsecureSkipVerify: cfg.InsecureSkipVerify, - ConfigPath: filepath.Join(cfg.VersionsDir, updateConfigName), + ConfigPath: filepath.Join(installDir, updateConfigName), Installer: &LocalInstaller{ - InstallDir: cfg.VersionsDir, + InstallDir: installDir, LinkBinDir: filepath.Join(cfg.LinkDir, "bin"), // For backwards-compatibility with symlinks created by package-based installs, we always // link into /lib/systemd/system, even though, e.g., /usr/local/lib/systemd/system would work. 
@@ -159,8 +168,24 @@ func NewLocalUpdater(cfg LocalUpdaterConfig) (*Updater, error) { }, Process: &SystemdService{ ServiceName: "teleport.service", + PIDPath: "/run/teleport.pid", Log: cfg.Log, }, + Setup: func(ctx context.Context) error { + name := filepath.Join(cfg.LinkDir, "bin", BinaryName) + if cfg.SelfSetup { + name = "/proc/self/exe" + } + cmd := exec.CommandContext(ctx, name, + "--data-dir", cfg.DataDir, + "--link-dir", cfg.LinkDir, + "setup") + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout + cfg.Log.InfoContext(ctx, "Executing new teleport-update binary to update configuration.") + defer cfg.Log.InfoContext(ctx, "Finished executing new teleport-update binary.") + return trace.Wrap(cmd.Run()) + }, }, nil } @@ -174,12 +199,14 @@ type LocalUpdaterConfig struct { // DownloadTimeout is a timeout for file download requests. // Defaults to no timeout. DownloadTimeout time.Duration - // VersionsDir for installing Teleport (usually /var/lib/teleport/versions). - VersionsDir string + // DataDir for Teleport (usually /var/lib/teleport). + DataDir string // LinkDir for installing Teleport (usually /usr/local). LinkDir string // SystemDir for package-installed Teleport installations (usually /usr/local/teleport-system). SystemDir string + // SelfSetup mode for using the current version of the teleport-update to setup the update service. + SelfSetup bool } // Updater implements the agent-local logic for Teleport agent auto-updates. @@ -196,6 +223,8 @@ type Updater struct { Installer Installer // Process manages a running instance of Teleport. Process Process + // Setup installs the Teleport updater service using the linked installation. + Setup func(ctx context.Context) error } // Installer provides an API for installing Teleport agents. 
@@ -336,6 +365,8 @@ func (u *Updater) Enable(ctx context.Context, override OverrideConfig) error { return trace.Errorf("agent version not available from Teleport cluster") } + u.Log.InfoContext(ctx, "Initiating initial update.", targetVersionKey, targetVersion, activeVersionKey, cfg.Status.ActiveVersion) + if err := u.update(ctx, cfg, targetVersion, flags); err != nil { return trace.Wrap(err) } @@ -476,7 +507,7 @@ func (u *Updater) update(ctx context.Context, cfg *UpdateConfig, targetVersion s } } - // Install the desired version (or validate existing installation) + // Install and link the desired version (or validate existing installation) template := cfg.Spec.URLTemplate if template == "" { @@ -486,20 +517,38 @@ func (u *Updater) update(ctx context.Context, cfg *UpdateConfig, targetVersion s if err != nil { return trace.Errorf("failed to install: %w", err) } + + // TODO(slevine): if the target version has fewer binaries, this will + // leave old binaries linked. This may prevent the installation from + // being removed. To fix this, we should look for orphaned binaries + // and remove them. + revert, err := u.Installer.Link(ctx, targetVersion) if err != nil { return trace.Errorf("failed to link: %w", err) } + // Verify that the linked installation contains a valid updater binary, + // and use it to update the updater's service files. + + if err := u.Setup(ctx); err != nil { + if ok := revert(ctx); !ok { + u.Log.ErrorContext(ctx, "Failed to revert Teleport symlinks. Installation likely broken.") + } + return trace.Errorf("failed to setup updater: %w", err) + } + // If we fail to revert after this point, the next update/enable will // fix the link to restore the active version. // Sync process configuration after linking. 
- if err := u.Process.Sync(ctx); err != nil { - if errors.Is(err, context.Canceled) { - return trace.Errorf("sync canceled") - } + err = u.Process.Sync(ctx) + if errors.Is(err, ErrNotSupported) { + u.Log.WarnContext(ctx, "Not syncing systemd configuration because systemd is not running.") + } else if errors.Is(err, context.Canceled) { + return trace.Errorf("sync canceled") + } else if err != nil { // If sync fails, we may have left the host in a bad state, so we revert linking and re-Sync. u.Log.ErrorContext(ctx, "Reverting symlinks due to invalid configuration.") if ok := revert(ctx); !ok { @@ -517,10 +566,14 @@ func (u *Updater) update(ctx context.Context, cfg *UpdateConfig, targetVersion s if cfg.Status.ActiveVersion != targetVersion { u.Log.InfoContext(ctx, "Target version successfully installed.", targetVersionKey, targetVersion) - if err := u.Process.Reload(ctx); err != nil && !errors.Is(err, ErrNotNeeded) { - if errors.Is(err, context.Canceled) { - return trace.Errorf("reload canceled") - } + err := u.Process.Reload(ctx) + if errors.Is(err, context.Canceled) { + return trace.Errorf("reload canceled") + } + if err != nil && + !errors.Is(err, ErrNotNeeded) && // no output if restart not needed + !errors.Is(err, ErrNotSupported) { // already logged above for Sync + // If reloading Teleport at the new version fails, revert, resync, and reload. 
u.Log.ErrorContext(ctx, "Reverting symlinks due to failed restart.") if ok := revert(ctx); !ok { @@ -608,7 +661,11 @@ func validateConfigSpec(spec *UpdateSpec, override OverrideConfig) error { if override.Group != "" { spec.Group = override.Group } - if override.URLTemplate != "" { + switch override.URLTemplate { + case "": + case "default": + spec.URLTemplate = "" + default: spec.URLTemplate = override.URLTemplate } if spec.URLTemplate != "" && diff --git a/lib/autoupdate/agent/updater_test.go b/lib/autoupdate/agent/updater_test.go index 1197ac3d5a795..8f919762f53fb 100644 --- a/lib/autoupdate/agent/updater_test.go +++ b/lib/autoupdate/agent/updater_test.go @@ -83,7 +83,13 @@ func TestUpdater_Disable(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { dir := t.TempDir() - cfgPath := filepath.Join(dir, "update.yaml") + cfgPath := filepath.Join(dir, VersionsDirName, "update.yaml") + + updater, err := NewLocalUpdater(LocalUpdaterConfig{ + InsecureSkipVerify: true, + DataDir: dir, + }) + require.NoError(t, err) // Create config file only if provided in test case if tt.cfg != nil { @@ -92,11 +98,7 @@ func TestUpdater_Disable(t *testing.T) { err = os.WriteFile(cfgPath, b, 0600) require.NoError(t, err) } - updater, err := NewLocalUpdater(LocalUpdaterConfig{ - InsecureSkipVerify: true, - VersionsDir: dir, - }) - require.NoError(t, err) + err = updater.Disable(context.Background()) if tt.errMatch != "" { require.Error(t, err) @@ -142,6 +144,7 @@ func TestUpdater_Update(t *testing.T) { syncCalls int reloadCalls int revertCalls int + setupCalls int errMatch string }{ { @@ -166,6 +169,7 @@ func TestUpdater_Update(t *testing.T) { requestGroup: "group", syncCalls: 1, reloadCalls: 1, + setupCalls: 1, }, { name: "updates disabled during window", @@ -295,6 +299,7 @@ func TestUpdater_Update(t *testing.T) { removedVersion: "backup-version", syncCalls: 1, reloadCalls: 1, + setupCalls: 1, }, { name: "backup version kept when no change", @@ -338,6 +343,7 @@ 
func TestUpdater_Update(t *testing.T) { removedVersion: "backup-version", syncCalls: 1, reloadCalls: 1, + setupCalls: 1, }, { name: "invalid metadata", @@ -368,6 +374,7 @@ func TestUpdater_Update(t *testing.T) { syncCalls: 2, reloadCalls: 0, revertCalls: 1, + setupCalls: 1, errMatch: "sync error", }, { @@ -394,6 +401,7 @@ func TestUpdater_Update(t *testing.T) { syncCalls: 2, reloadCalls: 2, revertCalls: 1, + setupCalls: 1, errMatch: "reload error", }, } @@ -419,7 +427,13 @@ func TestUpdater_Update(t *testing.T) { t.Cleanup(server.Close) dir := t.TempDir() - cfgPath := filepath.Join(dir, "update.yaml") + cfgPath := filepath.Join(dir, VersionsDirName, "update.yaml") + + updater, err := NewLocalUpdater(LocalUpdaterConfig{ + InsecureSkipVerify: true, + DataDir: dir, + }) + require.NoError(t, err) // Create config file only if provided in test case if tt.cfg != nil { @@ -430,12 +444,6 @@ func TestUpdater_Update(t *testing.T) { require.NoError(t, err) } - updater, err := NewLocalUpdater(LocalUpdaterConfig{ - InsecureSkipVerify: true, - VersionsDir: dir, - }) - require.NoError(t, err) - var ( installedVersion string installedTemplate string @@ -481,6 +489,12 @@ func TestUpdater_Update(t *testing.T) { }, } + var setupCalls int + updater.Setup = func(_ context.Context) error { + setupCalls++ + return nil + } + ctx := context.Background() err = updater.Update(ctx) if tt.errMatch != "" { @@ -498,6 +512,7 @@ func TestUpdater_Update(t *testing.T) { require.Equal(t, tt.syncCalls, syncCalls) require.Equal(t, tt.reloadCalls, reloadCalls) require.Equal(t, tt.revertCalls, revertCalls) + require.Equal(t, tt.setupCalls, setupCalls) if tt.cfg == nil { _, err := os.Stat(cfgPath) @@ -594,7 +609,13 @@ func TestUpdater_LinkPackage(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { dir := t.TempDir() - cfgPath := filepath.Join(dir, "update.yaml") + cfgPath := filepath.Join(dir, VersionsDirName, "update.yaml") + + updater, err := 
NewLocalUpdater(LocalUpdaterConfig{ + InsecureSkipVerify: true, + DataDir: dir, + }) + require.NoError(t, err) // Create config file only if provided in test case if tt.cfg != nil { @@ -604,12 +625,6 @@ func TestUpdater_LinkPackage(t *testing.T) { require.NoError(t, err) } - updater, err := NewLocalUpdater(LocalUpdaterConfig{ - InsecureSkipVerify: true, - VersionsDir: dir, - }) - require.NoError(t, err) - var tryLinkSystemCalls int updater.Installer = &testInstaller{ FuncTryLinkSystem: func(_ context.Context) error { @@ -659,6 +674,7 @@ func TestUpdater_Enable(t *testing.T) { syncCalls int reloadCalls int revertCalls int + setupCalls int errMatch string }{ { @@ -681,6 +697,7 @@ func TestUpdater_Enable(t *testing.T) { requestGroup: "group", syncCalls: 1, reloadCalls: 1, + setupCalls: 1, }, { name: "config from user", @@ -706,6 +723,7 @@ func TestUpdater_Enable(t *testing.T) { linkedVersion: "new-version", syncCalls: 1, reloadCalls: 1, + setupCalls: 1, }, { name: "already enabled", @@ -725,6 +743,7 @@ func TestUpdater_Enable(t *testing.T) { linkedVersion: "16.3.0", syncCalls: 1, reloadCalls: 1, + setupCalls: 1, }, { name: "insecure URL", @@ -766,6 +785,7 @@ func TestUpdater_Enable(t *testing.T) { linkedVersion: "16.3.0", syncCalls: 1, reloadCalls: 0, + setupCalls: 1, }, { name: "backup version removed on install", @@ -784,6 +804,7 @@ func TestUpdater_Enable(t *testing.T) { removedVersion: "backup-version", syncCalls: 1, reloadCalls: 1, + setupCalls: 1, }, { name: "backup version kept for validation", @@ -802,6 +823,7 @@ func TestUpdater_Enable(t *testing.T) { removedVersion: "", syncCalls: 1, reloadCalls: 0, + setupCalls: 1, }, { name: "config does not exist", @@ -811,6 +833,7 @@ func TestUpdater_Enable(t *testing.T) { linkedVersion: "16.3.0", syncCalls: 1, reloadCalls: 1, + setupCalls: 1, }, { name: "FIPS and Enterprise flags", @@ -820,6 +843,7 @@ func TestUpdater_Enable(t *testing.T) { linkedVersion: "16.3.0", syncCalls: 1, reloadCalls: 1, + setupCalls: 1, }, { 
name: "invalid metadata", @@ -836,6 +860,7 @@ func TestUpdater_Enable(t *testing.T) { syncCalls: 2, reloadCalls: 0, revertCalls: 1, + setupCalls: 1, errMatch: "sync error", }, { @@ -848,6 +873,7 @@ func TestUpdater_Enable(t *testing.T) { syncCalls: 2, reloadCalls: 2, revertCalls: 1, + setupCalls: 1, errMatch: "reload error", }, } @@ -855,7 +881,13 @@ func TestUpdater_Enable(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { dir := t.TempDir() - cfgPath := filepath.Join(dir, "update.yaml") + cfgPath := filepath.Join(dir, VersionsDirName, "update.yaml") + + updater, err := NewLocalUpdater(LocalUpdaterConfig{ + InsecureSkipVerify: true, + DataDir: dir, + }) + require.NoError(t, err) // Create config file only if provided in test case if tt.cfg != nil { @@ -886,12 +918,6 @@ func TestUpdater_Enable(t *testing.T) { tt.userCfg.Proxy = strings.TrimPrefix(server.URL, "https://") } - updater, err := NewLocalUpdater(LocalUpdaterConfig{ - InsecureSkipVerify: true, - VersionsDir: dir, - }) - require.NoError(t, err) - var ( installedVersion string installedTemplate string @@ -936,6 +962,11 @@ func TestUpdater_Enable(t *testing.T) { return tt.reloadErr }, } + var setupCalls int + updater.Setup = func(_ context.Context) error { + setupCalls++ + return nil + } ctx := context.Background() err = updater.Enable(ctx, tt.userCfg) @@ -954,6 +985,7 @@ func TestUpdater_Enable(t *testing.T) { require.Equal(t, tt.syncCalls, syncCalls) require.Equal(t, tt.reloadCalls, reloadCalls) require.Equal(t, tt.revertCalls, revertCalls) + require.Equal(t, tt.setupCalls, setupCalls) if tt.cfg == nil && err != nil { _, err := os.Stat(cfgPath) diff --git a/tool/teleport-update/main.go b/tool/teleport-update/main.go index d559ad3e75cdd..c2d8f07165d25 100644 --- a/tool/teleport-update/main.go +++ b/tool/teleport-update/main.go @@ -21,6 +21,7 @@ package main import ( "context" "errors" + "fmt" "log/slog" "os" "os/signal" @@ -58,10 +59,8 @@ const ( ) const ( - // versionsDirName 
specifies the name of the subdirectory inside of the Teleport data dir for storing Teleport versions. - versionsDirName = "versions" - // lockFileName specifies the name of the file inside versionsDirName containing the flock lock preventing concurrent updater execution. - lockFileName = ".lock" + // lockFileName specifies the name of the file containing the flock lock preventing concurrent updater execution. + lockFileName = ".update-lock" ) var plog = logutils.NewPackageLogger(teleport.ComponentKey, teleport.ComponentUpdater) @@ -84,6 +83,8 @@ type cliConfig struct { DataDir string // LinkDir for linking binaries and systemd services LinkDir string + // SelfSetup mode for using the current version of the teleport-update to setup the update service. + SelfSetup bool } func Run(args []string) error { @@ -91,7 +92,7 @@ func Run(args []string) error { ctx := context.Background() ctx, _ = signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM) - app := libutils.InitCLIParser("teleport-update", appHelp).Interspersed(false) + app := libutils.InitCLIParser(autoupdate.BinaryName, appHelp).Interspersed(false) app.Flag("debug", "Verbose logging to stdout."). Short('d').BoolVar(&ccfg.Debug) app.Flag("data-dir", "Teleport data directory. Access to this directory should be limited."). @@ -103,7 +104,7 @@ func Run(args []string) error { app.HelpFlag.Short('h') - versionCmd := app.Command("version", "Print the version of your teleport-updater binary.") + versionCmd := app.Command("version", fmt.Sprintf("Print the version of your %s binary.", autoupdate.BinaryName)) enableCmd := app.Command("enable", "Enable agent auto-updates and perform initial update.") enableCmd.Flag("proxy", "Address of the Teleport Proxy."). @@ -114,14 +115,21 @@ func Run(args []string) error { Short('t').Envar(templateEnvVar).StringVar(&ccfg.URLTemplate) enableCmd.Flag("force-version", "Force the provided version instead of querying it from the Teleport cluster."). 
Short('f').Envar(updateVersionEnvVar).Hidden().StringVar(&ccfg.ForceVersion) + enableCmd.Flag("self-setup", "Use the current teleport-update binary to create systemd service config for auto-updates."). + Short('s').Hidden().BoolVar(&ccfg.SelfSetup) // TODO(sclevine): add force-fips and force-enterprise as hidden flags disableCmd := app.Command("disable", "Disable agent auto-updates.") updateCmd := app.Command("update", "Update agent to the latest version, if a new version is available.") + updateCmd.Flag("self-setup", "Use the current teleport-update binary to create systemd service config for auto-updates."). + Short('s').Hidden().BoolVar(&ccfg.SelfSetup) linkCmd := app.Command("link", "Link the system installation of Teleport from the Teleport package, if auto-updates is disabled.") + setupCmd := app.Command("setup", "Write configuration files that run the update subcommand on a timer."). + Hidden() + libutils.UpdateAppUsageTemplate(app, args) command, err := app.Parse(args) if err != nil { @@ -143,6 +151,8 @@ func Run(args []string) error { err = cmdUpdate(ctx, &ccfg) case linkCmd.FullCommand(): err = cmdLink(ctx, &ccfg) + case setupCmd.FullCommand(): + err = cmdSetup(ctx, &ccfg) case versionCmd.FullCommand(): modules.GetModules().PrintVersion() default: @@ -172,12 +182,17 @@ func setupLogger(debug bool, format string) error { // cmdDisable disables updates. 
func cmdDisable(ctx context.Context, ccfg *cliConfig) error { - versionsDir := filepath.Join(ccfg.DataDir, versionsDirName) - if err := os.MkdirAll(versionsDir, 0755); err != nil { - return trace.Errorf("failed to create versions directory: %w", err) + updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{ + DataDir: ccfg.DataDir, + LinkDir: ccfg.LinkDir, + SystemDir: autoupdate.DefaultSystemDir, + SelfSetup: ccfg.SelfSetup, + Log: plog, + }) + if err != nil { + return trace.Errorf("failed to setup updater: %w", err) } - - unlock, err := libutils.FSWriteLock(filepath.Join(versionsDir, lockFileName)) + unlock, err := libutils.FSWriteLock(filepath.Join(ccfg.DataDir, lockFileName)) if err != nil { return trace.Errorf("failed to grab concurrent execution lock: %w", err) } @@ -186,15 +201,6 @@ func cmdDisable(ctx context.Context, ccfg *cliConfig) error { plog.DebugContext(ctx, "Failed to close lock file", "error", err) } }() - updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{ - VersionsDir: versionsDir, - LinkDir: ccfg.LinkDir, - SystemDir: autoupdate.DefaultSystemDir, - Log: plog, - }) - if err != nil { - return trace.Errorf("failed to setup updater: %w", err) - } if err := updater.Disable(ctx); err != nil { return trace.Wrap(err) } @@ -203,13 +209,19 @@ func cmdDisable(ctx context.Context, ccfg *cliConfig) error { // cmdEnable enables updates and triggers an initial update. 
func cmdEnable(ctx context.Context, ccfg *cliConfig) error { - versionsDir := filepath.Join(ccfg.DataDir, versionsDirName) - if err := os.MkdirAll(versionsDir, 0755); err != nil { - return trace.Errorf("failed to create versions directory: %w", err) + updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{ + DataDir: ccfg.DataDir, + LinkDir: ccfg.LinkDir, + SystemDir: autoupdate.DefaultSystemDir, + SelfSetup: ccfg.SelfSetup, + Log: plog, + }) + if err != nil { + return trace.Errorf("failed to setup updater: %w", err) } // Ensure enable can't run concurrently. - unlock, err := libutils.FSWriteLock(filepath.Join(versionsDir, lockFileName)) + unlock, err := libutils.FSWriteLock(filepath.Join(ccfg.DataDir, lockFileName)) if err != nil { return trace.Errorf("failed to grab concurrent execution lock: %w", err) } @@ -218,16 +230,6 @@ func cmdEnable(ctx context.Context, ccfg *cliConfig) error { plog.DebugContext(ctx, "Failed to close lock file", "error", err) } }() - - updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{ - VersionsDir: versionsDir, - LinkDir: ccfg.LinkDir, - SystemDir: autoupdate.DefaultSystemDir, - Log: plog, - }) - if err != nil { - return trace.Errorf("failed to setup updater: %w", err) - } if err := updater.Enable(ctx, ccfg.OverrideConfig); err != nil { return trace.Wrap(err) } @@ -236,13 +238,18 @@ func cmdEnable(ctx context.Context, ccfg *cliConfig) error { // cmdUpdate updates Teleport to the version specified by cluster reachable at the proxy address. 
func cmdUpdate(ctx context.Context, ccfg *cliConfig) error { - versionsDir := filepath.Join(ccfg.DataDir, versionsDirName) - if err := os.MkdirAll(versionsDir, 0755); err != nil { - return trace.Errorf("failed to create versions directory: %w", err) + updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{ + DataDir: ccfg.DataDir, + LinkDir: ccfg.LinkDir, + SystemDir: autoupdate.DefaultSystemDir, + SelfSetup: ccfg.SelfSetup, + Log: plog, + }) + if err != nil { + return trace.Errorf("failed to setup updater: %w", err) } - // Ensure update can't run concurrently. - unlock, err := libutils.FSWriteLock(filepath.Join(versionsDir, lockFileName)) + unlock, err := libutils.FSWriteLock(filepath.Join(ccfg.DataDir, lockFileName)) if err != nil { return trace.Errorf("failed to grab concurrent execution lock: %w", err) } @@ -252,15 +259,6 @@ func cmdUpdate(ctx context.Context, ccfg *cliConfig) error { } }() - updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{ - VersionsDir: versionsDir, - LinkDir: ccfg.LinkDir, - SystemDir: autoupdate.DefaultSystemDir, - Log: plog, - }) - if err != nil { - return trace.Errorf("failed to setup updater: %w", err) - } if err := updater.Update(ctx); err != nil { return trace.Wrap(err) } @@ -269,10 +267,19 @@ func cmdUpdate(ctx context.Context, ccfg *cliConfig) error { // cmdLink creates system package links if no version is linked and auto-updates is disabled. func cmdLink(ctx context.Context, ccfg *cliConfig) error { - versionsDir := filepath.Join(ccfg.DataDir, versionsDirName) + updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{ + DataDir: ccfg.DataDir, + LinkDir: ccfg.LinkDir, + SystemDir: autoupdate.DefaultSystemDir, + SelfSetup: ccfg.SelfSetup, + Log: plog, + }) + if err != nil { + return trace.Errorf("failed to setup updater: %w", err) + } // Skip operation and warn if the updater is currently running. 
- unlock, err := libutils.FSTryReadLock(filepath.Join(versionsDir, lockFileName)) + unlock, err := libutils.FSTryReadLock(filepath.Join(ccfg.DataDir, lockFileName)) if errors.Is(err, libutils.ErrUnsuccessfulLockTry) { plog.WarnContext(ctx, "Updater is currently running. Skipping package linking.") return nil @@ -285,17 +292,18 @@ func cmdLink(ctx context.Context, ccfg *cliConfig) error { plog.DebugContext(ctx, "Failed to close lock file", "error", err) } }() - updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{ - VersionsDir: versionsDir, - LinkDir: ccfg.LinkDir, - SystemDir: autoupdate.DefaultSystemDir, - Log: plog, - }) - if err != nil { - return trace.Errorf("failed to setup updater: %w", err) - } + if err := updater.LinkPackage(ctx); err != nil { return trace.Wrap(err) } return nil } + +// cmdSetup writes configuration files that are needed to run teleport-update update. +func cmdSetup(ctx context.Context, ccfg *cliConfig) error { + err := autoupdate.Setup(ctx, plog, ccfg.LinkDir, ccfg.DataDir) + if err != nil { + return trace.Errorf("failed to setup teleport-update service: %w", err) + } + return nil +}