Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
sclevine committed Oct 28, 2024
1 parent 02738ef commit 45b0fbd
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 31 deletions.
36 changes: 27 additions & 9 deletions lib/autoupdate/agent/installer.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ var (
// See utils.Extract for more details on how this list is parsed.
// Paths must use tarball-style / separators (not filepath).
tgzExtractPaths = []utils.ExtractPath{
{Src: "teleport/examples/systemd/teleport.service", Dst: "etc/systemd/teleport.service"},
{Src: "teleport/examples", Skip: true},
{Src: "teleport/install", Skip: true},
{Src: "teleport/README.md", Dst: "share/README.md"},
{Src: "teleport/CHANGELOG.md", Dst: "share/CHANGELOG.md"},
{Src: "teleport/VERSION", Dst: "share/VERSION"},
{Src: "teleport", Dst: "bin"},
{Src: "teleport/examples/systemd/teleport.service", Dst: "etc/systemd/teleport.service", DirMode: 0755},
{Src: "teleport/examples", Skip: true, DirMode: 0755},
{Src: "teleport/install", Skip: true, DirMode: 0755},
{Src: "teleport/README.md", Dst: "share/README.md", DirMode: 0755},
{Src: "teleport/CHANGELOG.md", Dst: "share/CHANGELOG.md", DirMode: 0755},
{Src: "teleport/VERSION", Dst: "share/VERSION", DirMode: 0755},
{Src: "teleport", Dst: "bin", DirMode: 0755},
}

// servicePath contains the path to the Teleport SystemD service within the version directory.
Expand Down Expand Up @@ -119,7 +119,7 @@ func (li *LocalInstaller) Remove(ctx context.Context, version string) error {
// Install a Teleport version directory in InstallDir.
// This function is idempotent.
// See Installer interface for additional specs.
func (li *LocalInstaller) Install(ctx context.Context, version, template string, flags InstallFlags) error {
func (li *LocalInstaller) Install(ctx context.Context, version, template string, flags InstallFlags) (err error) {
versionDir, err := li.versionDir(version)
if err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -175,11 +175,18 @@ func (li *LocalInstaller) Install(ctx context.Context, version, template string,
if err != nil {
return trace.Errorf("failed to download teleport: %w", err)
}

// Seek to the start of the tgz file after writing
if _, err := f.Seek(0, io.SeekStart); err != nil {
return trace.Errorf("failed seek to start of download: %w", err)
}

// If interrupted, close the file immediately to stop extracting.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
<-ctx.Done()
_ = f.Close()
}()
// Check integrity before decompression
if !bytes.Equal(newSum, pathSum) {
return trace.Errorf("mismatched checksum, download possibly corrupt")
Expand All @@ -193,6 +200,17 @@ func (li *LocalInstaller) Install(ctx context.Context, version, template string,
if _, err := f.Seek(0, io.SeekStart); err != nil {
return trace.Errorf("failed seek to start: %w", err)
}

// If there's an error after we start extracting, delete the version dir.
defer func() {
if err != nil {
if err := os.RemoveAll(versionDir); err != nil {
li.Log.WarnContext(ctx, "Failed to cleanup broken version extraction.", "error", err, "dir", versionDir)
}
}
}()

// Extract tgz into version directory.
if err := li.extract(ctx, versionDir, f, n); err != nil {
return trace.Errorf("failed to extract teleport: %w", err)
}
Expand Down
3 changes: 3 additions & 0 deletions lib/autoupdate/agent/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ func (w *lineLogger) Write(p []byte) (n int, err error) {
}

func (w *lineLogger) Flush() {
if w.last.Len() == 0 {
return
}
w.log.Log(w.ctx, w.level, w.last.String()) //nolint:sloglint
w.last.Reset()
}
47 changes: 26 additions & 21 deletions lib/utils/unpack.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,16 @@ func Extract(r io.Reader, dir string, paths ...ExtractPath) error {
} else if err != nil {
return trace.Wrap(err)
}
if ok := filterHeader(header, paths); !ok {
dirMode, ok := filterHeader(header, paths)
if !ok {
continue
}
err = sanitizeTarPath(header, dir)
if err != nil {
return trace.Wrap(err)
}

if err := extractFile(tarball, header, dir); err != nil {
if err := extractFile(tarball, header, dir, dirMode); err != nil {
return trace.Wrap(err)
}
}
Expand All @@ -74,11 +75,15 @@ type ExtractPath struct {
Src, Dst string
// Skip extracting the Src path and ignore Dst.
Skip bool
// DirMode is the file mode for implicit parent directories in Dst.
DirMode os.FileMode
}

// filterHeader modifies the tar header by filtering it through the ExtractPaths.
// filterHeader returns false if the tar header should be skipped.
func filterHeader(hdr *tar.Header, paths []ExtractPath) (include bool) {
// If no paths are provided, filterHeader assumes the header should be included, and sets
// the mode for implicit parent directories to teleport.DirMaskSharedGroup.
func filterHeader(hdr *tar.Header, paths []ExtractPath) (dirMode os.FileMode, include bool) {
name := path.Clean(hdr.Name)
for _, p := range paths {
src := path.Clean(p.Src)
Expand All @@ -98,14 +103,14 @@ func filterHeader(hdr *tar.Header, paths []ExtractPath) (include bool) {
dst += "/" // tar directory headers end in /
}
hdr.Name = dst
return !p.Skip
return p.DirMode, !p.Skip
default:
// If name is a file, then
// if src is an exact match to the file name, assume src is a file and write directly to dst,
// otherwise, assume src is a directory prefix, and replace that prefix with dst.
if src == name {
hdr.Name = path.Clean(p.Dst)
return !p.Skip
return p.DirMode, !p.Skip
}
if src != "/" {
src += "/" // ensure HasPrefix does not match partial names
Expand All @@ -114,34 +119,34 @@ func filterHeader(hdr *tar.Header, paths []ExtractPath) (include bool) {
continue
}
hdr.Name = path.Join(p.Dst, strings.TrimPrefix(name, src))
return !p.Skip
return p.DirMode, !p.Skip

}
}
return len(paths) == 0
return teleport.DirMaskSharedGroup, len(paths) == 0
}

// extractFile extracts a single file or directory from tarball into dir.
// Uses header to determine the type of item to create
// Based on https://github.com/mholt/archiver
func extractFile(tarball *tar.Reader, header *tar.Header, dir string) error {
func extractFile(tarball *tar.Reader, header *tar.Header, dir string, dirMode os.FileMode) error {
switch header.Typeflag {
case tar.TypeDir:
return withDir(filepath.Join(dir, header.Name), nil)
return withDir(filepath.Join(dir, header.Name), dirMode, nil)
case tar.TypeBlock, tar.TypeChar, tar.TypeReg, tar.TypeFifo:
return writeFile(filepath.Join(dir, header.Name), tarball, header.FileInfo().Mode())
return writeFile(filepath.Join(dir, header.Name), tarball, header.FileInfo().Mode(), dirMode)
case tar.TypeLink:
return writeHardLink(filepath.Join(dir, header.Name), filepath.Join(dir, header.Linkname))
return writeHardLink(filepath.Join(dir, header.Name), filepath.Join(dir, header.Linkname), dirMode)
case tar.TypeSymlink:
return writeSymbolicLink(filepath.Join(dir, header.Name), header.Linkname)
return writeSymbolicLink(filepath.Join(dir, header.Name), header.Linkname, dirMode)
default:
log.Warnf("Unsupported type flag %v for %v.", header.Typeflag, header.Name)
}
return nil
}

// sanitizeTarPath checks that the tar header paths resolve to a subdirectory
// path, and don't contain file paths or links that could escape the tar file
// path, and don't contain file paths or links that could escape the tar fileteleport.DirMaskSharedGroup
// like ../../etc/password.
func sanitizeTarPath(header *tar.Header, dir string) error {
// Sanitize all tar paths resolve to within the destination directory.
Expand All @@ -168,8 +173,8 @@ func sanitizeTarPath(header *tar.Header, dir string) error {
return nil
}

func writeFile(path string, r io.Reader, mode os.FileMode) error {
err := withDir(path, func() error {
func writeFile(path string, r io.Reader, mode, dirMode os.FileMode) error {
err := withDir(path, dirMode, func() error {
// Create file only if it does not exist to prevent overwriting existing
// files (like session recordings).
out, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_WRONLY, mode)
Expand All @@ -182,24 +187,24 @@ func writeFile(path string, r io.Reader, mode os.FileMode) error {
return trace.Wrap(err)
}

func writeSymbolicLink(path string, target string) error {
err := withDir(path, func() error {
func writeSymbolicLink(path, target string, dirMode os.FileMode) error {
err := withDir(path, dirMode, func() error {
err := os.Symlink(target, path)
return trace.ConvertSystemError(err)
})
return trace.Wrap(err)
}

func writeHardLink(path string, target string) error {
err := withDir(path, func() error {
func writeHardLink(path, target string, dirMode os.FileMode) error {
err := withDir(path, dirMode, func() error {
err := os.Link(target, path)
return trace.ConvertSystemError(err)
})
return trace.Wrap(err)
}

func withDir(path string, fn func() error) error {
err := os.MkdirAll(filepath.Dir(path), teleport.DirMaskSharedGroup)
func withDir(path string, mode os.FileMode, fn func() error) error {
err := os.MkdirAll(filepath.Dir(path), mode)
if err != nil {
return trace.ConvertSystemError(err)
}
Expand Down
4 changes: 3 additions & 1 deletion tool/teleport-update/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ const (
versionsDirName = "versions"
// lockFileName specifies the name of the file inside versionsDirName containing the flock lock preventing concurrent updater execution.
lockFileName = ".lock"
// defaultLinkDir is the default location where Teleport binaries and services are linked.
defaultLinkDir = "/usr/local"
)

var plog = logutils.NewPackageLogger(teleport.ComponentKey, teleport.ComponentUpdater)
Expand Down Expand Up @@ -98,7 +100,7 @@ func Run(args []string) error {
app.Flag("log-format", "Controls the format of output logs. Can be `json` or `text`. Defaults to `text`.").
Default(libutils.LogFormatText).EnumVar(&ccfg.LogFormat, libutils.LogFormatJSON, libutils.LogFormatText)
app.Flag("link-dir", "Directory to create system symlinks to binaries and services.").
Default(filepath.Join("usr", "local")).Hidden().StringVar(&ccfg.LinkDir)
Default(defaultLinkDir).Hidden().StringVar(&ccfg.LinkDir)

app.HelpFlag.Short('h')

Expand Down

0 comments on commit 45b0fbd

Please sign in to comment.