diff --git a/common.go b/common.go new file mode 100644 index 0000000..8f5e2c0 --- /dev/null +++ b/common.go @@ -0,0 +1,48 @@ +// Copyright (C) 2014-2015 Docker Inc & Go Authors. All rights reserved. +// Copyright (C) 2017-2025 SUSE LLC. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package securejoin + +import ( + "errors" + "os" + "path/filepath" + "strings" + "syscall" +) + +// IsNotExist tells you if err is an error that implies that either the path +// accessed does not exist (or path components don't exist). This is +// effectively a more broad version of [os.IsNotExist]. +func IsNotExist(err error) bool { + // Check that it's not actually an ENOTDIR, which in some cases is a more + // convoluted case of ENOENT (usually involving weird paths). + return errors.Is(err, os.ErrNotExist) || errors.Is(err, syscall.ENOTDIR) || errors.Is(err, syscall.ENOENT) +} + +// errUnsafeRoot is returned if the user provides SecureJoinVFS with a path +// that contains ".." components. +var errUnsafeRoot = errors.New("root path provided to SecureJoin contains '..' components") + +// hasDotDot checks if the path contains ".." components in a platform-agnostic +// way. +func hasDotDot(path string) bool { + // If we are on Windows, strip any volume letters. It turns out that + // C:..\foo may (or may not) be a valid pathname and we need to handle that + // leading "..". + path = stripVolume(path) + // Look for "/../" in the path, but we need to handle leading and trailing + // ".."s by adding separators. Doing this with filepath.Separator is ugly + // so just convert to Unix-style "/" first. + path = filepath.ToSlash(path) + return strings.Contains("/"+path+"/", "/../") +} + +// stripVolume just gets rid of the Windows volume included in a path. Based on +// some godbolt tests, the Go compiler is smart enough to make this a no-op on +// Linux. +func stripVolume(path string) string { + return path[len(filepath.VolumeName(path)):] +} diff --git a/join.go b/join.go index 52cc485..18860c9 100644 --- a/join.go +++ b/join.go @@ -3,10 +3,11 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build !plan9 + package securejoin import ( - "errors" "os" "path/filepath" "strings" @@ -15,40 +16,6 @@ import ( const maxSymlinkLimit = 255 -// IsNotExist tells you if err is an error that implies that either the path -// accessed does not exist (or path components don't exist). This is -// effectively a more broad version of [os.IsNotExist]. -func IsNotExist(err error) bool { - // Check that it's not actually an ENOTDIR, which in some cases is a more - // convoluted case of ENOENT (usually involving weird paths). - return errors.Is(err, os.ErrNotExist) || errors.Is(err, syscall.ENOTDIR) || errors.Is(err, syscall.ENOENT) -} - -// errUnsafeRoot is returned if the user provides SecureJoinVFS with a path -// that contains ".." components. -var errUnsafeRoot = errors.New("root path provided to SecureJoin contains '..' components") - -// stripVolume just gets rid of the Windows volume included in a path. Based on -// some godbolt tests, the Go compiler is smart enough to make this a no-op on -// Linux. -func stripVolume(path string) string { - return path[len(filepath.VolumeName(path)):] -} - -// hasDotDot checks if the path contains ".." components in a platform-agnostic -// way. -func hasDotDot(path string) bool { - // If we are on Windows, strip any volume letters. It turns out that - // C:..\foo may (or may not) be a valid pathname and we need to handle that - // leading "..". - path = stripVolume(path) - // Look for "/../" in the path, but we need to handle leading and trailing - // ".."s by adding separators. Doing this with filepath.Separator is ugly - // so just convert to Unix-style "/" first. - path = filepath.ToSlash(path) - return strings.Contains("/"+path+"/", "/../") -} - // SecureJoinVFS joins the two given path components (similar to [filepath.Join]) except // that the returned path is guaranteed to be scoped inside the provided root // path (when evaluated). Any symbolic links in the path are evaluated with the diff --git a/join_plan9.go b/join_plan9.go new file mode 100644 index 0000000..8165a78 --- /dev/null +++ b/join_plan9.go @@ -0,0 +1,27 @@ +// Copyright (C) 2014-2015 Docker Inc & Go Authors. All rights reserved. +// Copyright (C) 2017-2025 SUSE LLC. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package securejoin + +import "path/filepath" + +// SecureJoin is equivalent to filepath.Join, as plan9 doesn't have symlinks. +func SecureJoin(root, unsafePath string) (string, error) { + // The root path must not contain ".." components, otherwise when we join + // the subpath we will end up with a weird path. We could work around this + // in other ways but users shouldn't be giving us non-lexical root paths in + // the first place. + if hasDotDot(root) { + return "", errUnsafeRoot + } + + unsafePath = filepath.Join(string(filepath.Separator), unsafePath) + return filepath.Join(root, unsafePath), nil +} + +// SecureJoinVFS is equivalent to filepath.Join, as plan9 doesn't have symlinks. +func SecureJoinVFS(root, unsafePath string, _ VFS) (string, error) { + return SecureJoin(root, unsafePath) +} diff --git a/join_test.go b/join_test.go index 7ec788c..b9894bb 100644 --- a/join_test.go +++ b/join_test.go @@ -13,7 +13,6 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) // TODO: These tests won't work on plan9 because it doesn't have symlinks, and @@ -24,64 +23,9 @@ type input struct { expected string } -func expandedTempDir(t *testing.T) string { - dir := t.TempDir() - dir, err := filepath.EvalSymlinks(dir) - require.NoError(t, err) - return dir -} - -// Test basic handling of symlink expansion. -func TestSymlink(t *testing.T) { - dir := expandedTempDir(t) - - symlink(t, "somepath", filepath.Join(dir, "etc")) - symlink(t, "../../../../../../../../../../../../../etc", filepath.Join(dir, "etclink")) - symlink(t, "/../../../../../../../../../../../../../etc/passwd", filepath.Join(dir, "passwd")) - - rootOrVol := string(filepath.Separator) - if vol := filepath.VolumeName(dir); vol != "" { - rootOrVol = vol + rootOrVol - } - - tc := []input{ - // Make sure that expansion with a root of '/' proceeds in the expected fashion. - {rootOrVol, filepath.Join(dir, "passwd"), filepath.Join(rootOrVol, "etc", "passwd")}, - {rootOrVol, filepath.Join(dir, "etclink"), filepath.Join(rootOrVol, "etc")}, - - {rootOrVol, filepath.Join(dir, "etc"), filepath.Join(dir, "somepath")}, - // Now test scoped expansion. - {dir, "passwd", filepath.Join(dir, "somepath", "passwd")}, - {dir, "etclink", filepath.Join(dir, "somepath")}, - {dir, "etc", filepath.Join(dir, "somepath")}, - {dir, "etc/test", filepath.Join(dir, "somepath", "test")}, - {dir, "etc/test/..", filepath.Join(dir, "somepath")}, - } - - for _, test := range tc { - got, err := SecureJoin(test.root, test.unsafe) - if err != nil { - t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err) - continue - } - // This is only for OS X, where /etc is a symlink to /private/etc. In - // principle, SecureJoin(/, pth) is the same as EvalSymlinks(pth) in - // the case where the path exists. - if test.root == "/" { - if expected, err := filepath.EvalSymlinks(test.expected); err == nil { - test.expected = expected - } - } - if got != test.expected { - t.Errorf("securejoin(%q, %q): expected %q, got %q", test.root, test.unsafe, test.expected, got) - continue - } - } -} - // In a path without symlinks, SecureJoin is equivalent to Clean+Join. func TestNoSymlink(t *testing.T) { - dir := expandedTempDir(t) + dir := t.TempDir() tc := []input{ {dir, "somepath", filepath.Join(dir, "somepath")}, @@ -112,92 +56,6 @@ func TestNoSymlink(t *testing.T) { } } -// Make sure that .. is **not** expanded lexically. -func TestNonLexical(t *testing.T) { - dir := expandedTempDir(t) - - mkdirAll(t, filepath.Join(dir, "subdir"), 0o755) - mkdirAll(t, filepath.Join(dir, "cousinparent", "cousin"), 0o755) - symlink(t, "../cousinparent/cousin", filepath.Join(dir, "subdir", "link")) - symlink(t, "/../cousinparent/cousin", filepath.Join(dir, "subdir", "link2")) - symlink(t, "/../../../../../../../../../../../../../../../../cousinparent/cousin", filepath.Join(dir, "subdir", "link3")) - - for _, test := range []input{ - {dir, "subdir", filepath.Join(dir, "subdir")}, - {dir, "subdir/link/test", filepath.Join(dir, "cousinparent", "cousin", "test")}, - {dir, "subdir/link2/test", filepath.Join(dir, "cousinparent", "cousin", "test")}, - {dir, "subdir/link3/test", filepath.Join(dir, "cousinparent", "cousin", "test")}, - {dir, "subdir/../test", filepath.Join(dir, "test")}, - // This is the divergence from a simple filepath.Clean implementation. - {dir, "subdir/link/../test", filepath.Join(dir, "cousinparent", "test")}, - {dir, "subdir/link2/../test", filepath.Join(dir, "cousinparent", "test")}, - {dir, "subdir/link3/../test", filepath.Join(dir, "cousinparent", "test")}, - } { - got, err := SecureJoin(test.root, test.unsafe) - if err != nil { - t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err) - continue - } - if got != test.expected { - t.Errorf("securejoin(%q, %q): expected %q, got %q", test.root, test.unsafe, test.expected, got) - continue - } - } -} - -// Make sure that symlink loops result in errors. -func TestSymlinkLoop(t *testing.T) { - dir := expandedTempDir(t) - - mkdirAll(t, filepath.Join(dir, "subdir"), 0o755) - symlink(t, "../../../../../../../../../../../../../../../../path", filepath.Join(dir, "subdir", "link")) - symlink(t, "/subdir/link", filepath.Join(dir, "path")) - symlink(t, "/../../../../../../../../../../../../../../../../self", filepath.Join(dir, "self")) - - for _, test := range []struct { - root, unsafe string - }{ - {dir, "subdir/link"}, - {dir, "path"}, - {dir, "../../path"}, - {dir, "subdir/link/../.."}, - {dir, "../../../../../../../../../../../../../../../../subdir/link/../../../../../../../../../../../../../../../.."}, - {dir, "self"}, - {dir, "self/.."}, - {dir, "/../../../../../../../../../../../../../../../../self/.."}, - {dir, "/self/././.."}, - } { - got, err := SecureJoin(test.root, test.unsafe) - if !errors.Is(err, syscall.ELOOP) { - t.Errorf("securejoin(%q, %q): expected ELOOP, got %q & %v", test.root, test.unsafe, got, err) - continue - } - } -} - -// Make sure that ENOTDIR is correctly handled. -func TestEnotdir(t *testing.T) { - dir := expandedTempDir(t) - - mkdirAll(t, filepath.Join(dir, "subdir"), 0o755) - writeFile(t, filepath.Join(dir, "notdir"), []byte("I am not a directory!"), 0o755) - symlink(t, "/../../../notdir/somechild", filepath.Join(dir, "subdir", "link")) - - for _, test := range []struct { - root, unsafe string - }{ - {dir, "subdir/link"}, - {dir, "notdir"}, - {dir, "notdir/child"}, - } { - _, err := SecureJoin(test.root, test.unsafe) - if err != nil { - t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err) - continue - } - } -} - // Some silly tests to make sure that all error types are correctly handled. func TestIsNotExist(t *testing.T) { for _, test := range []struct { @@ -222,121 +80,6 @@ func TestIsNotExist(t *testing.T) { } } -type mockVFS struct { - lstat func(path string) (os.FileInfo, error) - readlink func(path string) (string, error) -} - -func (m mockVFS) Lstat(path string) (os.FileInfo, error) { return m.lstat(path) } -func (m mockVFS) Readlink(path string) (string, error) { return m.readlink(path) } - -// Make sure that SecureJoinVFS actually does use the given VFS interface. -func TestSecureJoinVFS(t *testing.T) { - dir := expandedTempDir(t) - - mkdirAll(t, filepath.Join(dir, "subdir"), 0o755) - mkdirAll(t, filepath.Join(dir, "cousinparent", "cousin"), 0o755) - symlink(t, "../cousinparent/cousin", filepath.Join(dir, "subdir", "link")) - symlink(t, "/../cousinparent/cousin", filepath.Join(dir, "subdir", "link2")) - symlink(t, "/../../../../../../../../../../../../../../../../cousinparent/cousin", filepath.Join(dir, "subdir", "link3")) - - for _, test := range []input{ - {dir, "subdir", filepath.Join(dir, "subdir")}, - {dir, "subdir/link/test", filepath.Join(dir, "cousinparent", "cousin", "test")}, - {dir, "subdir/link2/test", filepath.Join(dir, "cousinparent", "cousin", "test")}, - {dir, "subdir/link3/test", filepath.Join(dir, "cousinparent", "cousin", "test")}, - {dir, "subdir/../test", filepath.Join(dir, "test")}, - // This is the divergence from a simple filepath.Clean implementation. - {dir, "subdir/link/../test", filepath.Join(dir, "cousinparent", "test")}, - {dir, "subdir/link2/../test", filepath.Join(dir, "cousinparent", "test")}, - {dir, "subdir/link3/../test", filepath.Join(dir, "cousinparent", "test")}, - } { - var nLstat, nReadlink int - mock := mockVFS{ - lstat: func(path string) (os.FileInfo, error) { nLstat++; return os.Lstat(path) }, - readlink: func(path string) (string, error) { nReadlink++; return os.Readlink(path) }, - } - - got, err := SecureJoinVFS(test.root, test.unsafe, mock) - if err != nil { - t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err) - continue - } - if got != test.expected { - t.Errorf("securejoin(%q, %q): expected %q, got %q", test.root, test.unsafe, test.expected, got) - continue - } - if nLstat == 0 && nReadlink == 0 { - t.Errorf("securejoin(%q, %q): expected to use either lstat or readlink, neither were used", test.root, test.unsafe) - } - } -} - -// Make sure that SecureJoinVFS actually does use the given VFS interface, and -// that errors are correctly propagated. -func TestSecureJoinVFSErrors(t *testing.T) { - var ( - lstatErr = errors.New("lstat error") - readlinkErr = errors.New("readlink err") - ) - - dir := expandedTempDir(t) - - // Make a link. - symlink(t, "../../../../../../../../../../../../../../../../path", filepath.Join(dir, "link")) - - // Define some fake mock functions. - lstatFailFn := func(string) (os.FileInfo, error) { return nil, lstatErr } - readlinkFailFn := func(string) (string, error) { return "", readlinkErr } - - // Make sure that the set of {lstat, readlink} failures do propagate. - for idx, test := range []struct { - vfs VFS - expected []error - }{ - { - expected: []error{nil}, - vfs: mockVFS{ - lstat: os.Lstat, - readlink: os.Readlink, - }, - }, - { - expected: []error{lstatErr}, - vfs: mockVFS{ - lstat: lstatFailFn, - readlink: os.Readlink, - }, - }, - { - expected: []error{readlinkErr}, - vfs: mockVFS{ - lstat: os.Lstat, - readlink: readlinkFailFn, - }, - }, - { - expected: []error{lstatErr, readlinkErr}, - vfs: mockVFS{ - lstat: lstatFailFn, - readlink: readlinkFailFn, - }, - }, - } { - _, err := SecureJoinVFS(dir, "link", test.vfs) - - success := false - for _, exp := range test.expected { - if errors.Is(err, exp) { - success = true - } - } - if !success { - t.Errorf("SecureJoinVFS.mock%d: expected to get lstatError, got %v", idx, err) - } - } -} - func TestUncleanRoot(t *testing.T) { root := t.TempDir() diff --git a/symlink_test.go b/symlink_test.go new file mode 100644 index 0000000..9739639 --- /dev/null +++ b/symlink_test.go @@ -0,0 +1,276 @@ +// Copyright (C) 2017-2025 SUSE LLC. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !plan9 + +package securejoin + +import ( + "errors" + "os" + "path/filepath" + "syscall" + "testing" + + "github.com/stretchr/testify/require" +) + +// TODO: These tests won't work on plan9 because it doesn't have symlinks, and +// also we use '/' here explicitly which probably won't work on Windows. + +func expandedTempDir(t *testing.T) string { + dir := t.TempDir() + dir, err := filepath.EvalSymlinks(dir) + require.NoError(t, err) + return dir +} + +// Test basic handling of symlink expansion. +func TestSymlink(t *testing.T) { + dir := expandedTempDir(t) + + symlink(t, "somepath", filepath.Join(dir, "etc")) + symlink(t, "../../../../../../../../../../../../../etc", filepath.Join(dir, "etclink")) + symlink(t, "/../../../../../../../../../../../../../etc/passwd", filepath.Join(dir, "passwd")) + + rootOrVol := string(filepath.Separator) + if vol := filepath.VolumeName(dir); vol != "" { + rootOrVol = vol + rootOrVol + } + + tc := []input{ + // Make sure that expansion with a root of '/' proceeds in the expected fashion. + {rootOrVol, filepath.Join(dir, "passwd"), filepath.Join(rootOrVol, "etc", "passwd")}, + {rootOrVol, filepath.Join(dir, "etclink"), filepath.Join(rootOrVol, "etc")}, + + {rootOrVol, filepath.Join(dir, "etc"), filepath.Join(dir, "somepath")}, + // Now test scoped expansion. + {dir, "passwd", filepath.Join(dir, "somepath", "passwd")}, + {dir, "etclink", filepath.Join(dir, "somepath")}, + {dir, "etc", filepath.Join(dir, "somepath")}, + {dir, "etc/test", filepath.Join(dir, "somepath", "test")}, + {dir, "etc/test/..", filepath.Join(dir, "somepath")}, + } + + for _, test := range tc { + got, err := SecureJoin(test.root, test.unsafe) + if err != nil { + t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err) + continue + } + // This is only for OS X, where /etc is a symlink to /private/etc. In + // principle, SecureJoin(/, pth) is the same as EvalSymlinks(pth) in + // the case where the path exists. + if test.root == "/" { + if expected, err := filepath.EvalSymlinks(test.expected); err == nil { + test.expected = expected + } + } + if got != test.expected { + t.Errorf("securejoin(%q, %q): expected %q, got %q", test.root, test.unsafe, test.expected, got) + continue + } + } +} + +// Make sure that .. is **not** expanded lexically. +func TestNonLexical(t *testing.T) { + dir := expandedTempDir(t) + + mkdirAll(t, filepath.Join(dir, "subdir"), 0o755) + mkdirAll(t, filepath.Join(dir, "cousinparent", "cousin"), 0o755) + symlink(t, "../cousinparent/cousin", filepath.Join(dir, "subdir", "link")) + symlink(t, "/../cousinparent/cousin", filepath.Join(dir, "subdir", "link2")) + symlink(t, "/../../../../../../../../../../../../../../../../cousinparent/cousin", filepath.Join(dir, "subdir", "link3")) + + for _, test := range []input{ + {dir, "subdir", filepath.Join(dir, "subdir")}, + {dir, "subdir/link/test", filepath.Join(dir, "cousinparent", "cousin", "test")}, + {dir, "subdir/link2/test", filepath.Join(dir, "cousinparent", "cousin", "test")}, + {dir, "subdir/link3/test", filepath.Join(dir, "cousinparent", "cousin", "test")}, + {dir, "subdir/../test", filepath.Join(dir, "test")}, + // This is the divergence from a simple filepath.Clean implementation. + {dir, "subdir/link/../test", filepath.Join(dir, "cousinparent", "test")}, + {dir, "subdir/link2/../test", filepath.Join(dir, "cousinparent", "test")}, + {dir, "subdir/link3/../test", filepath.Join(dir, "cousinparent", "test")}, + } { + got, err := SecureJoin(test.root, test.unsafe) + if err != nil { + t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err) + continue + } + if got != test.expected { + t.Errorf("securejoin(%q, %q): expected %q, got %q", test.root, test.unsafe, test.expected, got) + continue + } + } +} + +// Make sure that symlink loops result in errors. +func TestSymlinkLoop(t *testing.T) { + dir := expandedTempDir(t) + + mkdirAll(t, filepath.Join(dir, "subdir"), 0o755) + symlink(t, "../../../../../../../../../../../../../../../../path", filepath.Join(dir, "subdir", "link")) + symlink(t, "/subdir/link", filepath.Join(dir, "path")) + symlink(t, "/../../../../../../../../../../../../../../../../self", filepath.Join(dir, "self")) + + for _, test := range []struct { + root, unsafe string + }{ + {dir, "subdir/link"}, + {dir, "path"}, + {dir, "../../path"}, + {dir, "subdir/link/../.."}, + {dir, "../../../../../../../../../../../../../../../../subdir/link/../../../../../../../../../../../../../../../.."}, + {dir, "self"}, + {dir, "self/.."}, + {dir, "/../../../../../../../../../../../../../../../../self/.."}, + {dir, "/self/././.."}, + } { + got, err := SecureJoin(test.root, test.unsafe) + if !errors.Is(err, syscall.ELOOP) { + t.Errorf("securejoin(%q, %q): expected ELOOP, got %q & %v", test.root, test.unsafe, got, err) + continue + } + } +} + +// Make sure that ENOTDIR is correctly handled. +func TestEnotdir(t *testing.T) { + dir := expandedTempDir(t) + + mkdirAll(t, filepath.Join(dir, "subdir"), 0o755) + writeFile(t, filepath.Join(dir, "notdir"), []byte("I am not a directory!"), 0o755) + symlink(t, "/../../../notdir/somechild", filepath.Join(dir, "subdir", "link")) + + for _, test := range []struct { + root, unsafe string + }{ + {dir, "subdir/link"}, + {dir, "notdir"}, + {dir, "notdir/child"}, + } { + _, err := SecureJoin(test.root, test.unsafe) + if err != nil { + t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err) + continue + } + } +} + +type mockVFS struct { + lstat func(path string) (os.FileInfo, error) + readlink func(path string) (string, error) +} + +func (m mockVFS) Lstat(path string) (os.FileInfo, error) { return m.lstat(path) } +func (m mockVFS) Readlink(path string) (string, error) { return m.readlink(path) } + +// Make sure that SecureJoinVFS actually does use the given VFS interface. +func TestSecureJoinVFS(t *testing.T) { + dir := expandedTempDir(t) + + mkdirAll(t, filepath.Join(dir, "subdir"), 0o755) + mkdirAll(t, filepath.Join(dir, "cousinparent", "cousin"), 0o755) + symlink(t, "../cousinparent/cousin", filepath.Join(dir, "subdir", "link")) + symlink(t, "/../cousinparent/cousin", filepath.Join(dir, "subdir", "link2")) + symlink(t, "/../../../../../../../../../../../../../../../../cousinparent/cousin", filepath.Join(dir, "subdir", "link3")) + + for _, test := range []input{ + {dir, "subdir", filepath.Join(dir, "subdir")}, + {dir, "subdir/link/test", filepath.Join(dir, "cousinparent", "cousin", "test")}, + {dir, "subdir/link2/test", filepath.Join(dir, "cousinparent", "cousin", "test")}, + {dir, "subdir/link3/test", filepath.Join(dir, "cousinparent", "cousin", "test")}, + {dir, "subdir/../test", filepath.Join(dir, "test")}, + // This is the divergence from a simple filepath.Clean implementation. + {dir, "subdir/link/../test", filepath.Join(dir, "cousinparent", "test")}, + {dir, "subdir/link2/../test", filepath.Join(dir, "cousinparent", "test")}, + {dir, "subdir/link3/../test", filepath.Join(dir, "cousinparent", "test")}, + } { + var nLstat, nReadlink int + mock := mockVFS{ + lstat: func(path string) (os.FileInfo, error) { nLstat++; return os.Lstat(path) }, + readlink: func(path string) (string, error) { nReadlink++; return os.Readlink(path) }, + } + + got, err := SecureJoinVFS(test.root, test.unsafe, mock) + if err != nil { + t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err) + continue + } + if got != test.expected { + t.Errorf("securejoin(%q, %q): expected %q, got %q", test.root, test.unsafe, test.expected, got) + continue + } + if nLstat == 0 && nReadlink == 0 { + t.Errorf("securejoin(%q, %q): expected to use either lstat or readlink, neither were used", test.root, test.unsafe) + } + } +} + +// Make sure that SecureJoinVFS actually does use the given VFS interface, and +// that errors are correctly propagated. +func TestSecureJoinVFSErrors(t *testing.T) { + var ( + lstatErr = errors.New("lstat error") + readlinkErr = errors.New("readlink err") + ) + + dir := expandedTempDir(t) + + // Make a link. + symlink(t, "../../../../../../../../../../../../../../../../path", filepath.Join(dir, "link")) + + // Define some fake mock functions. + lstatFailFn := func(string) (os.FileInfo, error) { return nil, lstatErr } + readlinkFailFn := func(string) (string, error) { return "", readlinkErr } + + // Make sure that the set of {lstat, readlink} failures do propagate. + for idx, test := range []struct { + vfs VFS + expected []error + }{ + { + expected: []error{nil}, + vfs: mockVFS{ + lstat: os.Lstat, + readlink: os.Readlink, + }, + }, + { + expected: []error{lstatErr}, + vfs: mockVFS{ + lstat: lstatFailFn, + readlink: os.Readlink, + }, + }, + { + expected: []error{readlinkErr}, + vfs: mockVFS{ + lstat: os.Lstat, + readlink: readlinkFailFn, + }, + }, + { + expected: []error{lstatErr, readlinkErr}, + vfs: mockVFS{ + lstat: lstatFailFn, + readlink: readlinkFailFn, + }, + }, + } { + _, err := SecureJoinVFS(dir, "link", test.vfs) + + success := false + for _, exp := range test.expected { + if errors.Is(err, exp) { + success = true + } + } + if !success { + t.Errorf("SecureJoinVFS.mock%d: expected to get lstatError, got %v", idx, err) + } + } +}