diff --git a/pkg/remote/connparse/connparse.go b/pkg/remote/connparse/connparse.go index 9951065617..b099d1c0a9 100644 --- a/pkg/remote/connparse/connparse.go +++ b/pkg/remote/connparse/connparse.go @@ -128,8 +128,11 @@ func ParseURI(uri string) (*Connection, error) { } } + addPrecedingSlash := true + if scheme == "" { scheme = ConnectionTypeWsh + addPrecedingSlash = false if len(rest) != len(uri) { // This accounts for when the uri starts with "//", which would get trimmed in the first split. parseWshPath() @@ -152,7 +155,7 @@ func ParseURI(uri string) (*Connection, error) { } if strings.HasPrefix(remotePath, "/~") { remotePath = strings.TrimPrefix(remotePath, "/") - } else if len(remotePath) > 1 && !windowsDriveRegex.MatchString(remotePath) && !strings.HasPrefix(remotePath, "/") && !strings.HasPrefix(remotePath, "~") && !strings.HasPrefix(remotePath, "./") && !strings.HasPrefix(remotePath, "../") && !strings.HasPrefix(remotePath, ".\\") && !strings.HasPrefix(remotePath, "..\\") && remotePath != ".." { + } else if addPrecedingSlash && (len(remotePath) > 1 && !windowsDriveRegex.MatchString(remotePath) && !strings.HasPrefix(remotePath, "/") && !strings.HasPrefix(remotePath, "~") && !strings.HasPrefix(remotePath, "./") && !strings.HasPrefix(remotePath, "../") && !strings.HasPrefix(remotePath, ".\\") && !strings.HasPrefix(remotePath, "..\\") && remotePath != "..") { remotePath = "/" + remotePath } } diff --git a/pkg/remote/connparse/connparse_test.go b/pkg/remote/connparse/connparse_test.go index c530c8e768..82ccc83625 100644 --- a/pkg/remote/connparse/connparse_test.go +++ b/pkg/remote/connparse/connparse_test.go @@ -212,6 +212,50 @@ func TestParseURI_WSHCurrentPath(t *testing.T) { if c.GetFullURI() != expected { t.Fatalf("expected full URI to be %q, got %q", expected, c.GetFullURI()) } + + cstr = "path/to/file" + c, err = connparse.ParseURI(cstr) + if err != nil { + t.Fatalf("failed to parse URI: %v", err) + } + expected = "path/to/file" + if c.Path != expected { + t.Fatalf("expected path to be %q, got %q", expected, c.Path) + } + expected = "current" + if c.Host != expected { + t.Fatalf("expected host to be %q, got %q", expected, c.Host) + } + expected = "wsh" + if c.Scheme != expected { + t.Fatalf("expected scheme to be %q, got %q", expected, c.Scheme) + } + expected = "wsh://current/path/to/file" + if c.GetFullURI() != expected { + t.Fatalf("expected full URI to be %q, got %q", expected, c.GetFullURI()) + } + + cstr = "/etc/path/to/file" + c, err = connparse.ParseURI(cstr) + if err != nil { + t.Fatalf("failed to parse URI: %v", err) + } + expected = "/etc/path/to/file" + if c.Path != expected { + t.Fatalf("expected path to be %q, got %q", expected, c.Path) + } + expected = "current" + if c.Host != expected { + t.Fatalf("expected host to be %q, got %q", expected, c.Host) + } + expected = "wsh" + if c.Scheme != expected { + t.Fatalf("expected scheme to be %q, got %q", expected, c.Scheme) + } + expected = "wsh://current/etc/path/to/file" + if c.GetFullURI() != expected { + t.Fatalf("expected full URI to be %q, got %q", expected, c.GetFullURI()) + } } func TestParseURI_WSHCurrentPathWindows(t *testing.T) { diff --git a/pkg/wshrpc/wshremote/wshremote.go b/pkg/wshrpc/wshremote/wshremote.go index f324f52fec..f66c8cdaa9 100644 --- a/pkg/wshrpc/wshremote/wshremote.go +++ b/pkg/wshrpc/wshremote/wshremote.go @@ -348,11 +348,113 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C if err != nil { return fmt.Errorf("cannot parse source URI %q: %w", srcUri, err) } + + copyFileFunc := func(path string, finfo fs.FileInfo, srcFile io.Reader) (int64, error) { + destinfo, err = os.Stat(path) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return 0, fmt.Errorf("cannot stat file %q: %w", path, err) + } + + if destinfo != nil { + if destinfo.IsDir() { + if !finfo.IsDir() { + // try to create file in directory + path = filepath.Join(path, filepath.Base(finfo.Name())) + newdestinfo, err := os.Stat(path) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return 0, fmt.Errorf("cannot stat file %q: %w", path, err) + } + if newdestinfo != nil && !overwrite { + return 0, fmt.Errorf("cannot create file %q, file exists at path, overwrite not specified", path) + } + } else if !merge && !overwrite { + return 0, fmt.Errorf("cannot create directory %q, directory exists at path, neither overwrite nor merge specified", path) + } else if overwrite { + err := os.RemoveAll(path) + if err != nil { + return 0, fmt.Errorf("cannot remove directory %q: %w", path, err) + } + } + } else { + if finfo.IsDir() { + if !overwrite { + return 0, fmt.Errorf("cannot create file %q, directory exists at path, overwrite not specified", path) + } else { + err := os.RemoveAll(path) + if err != nil { + return 0, fmt.Errorf("cannot remove directory %q: %w", path, err) + } + } + } else if !overwrite { + return 0, fmt.Errorf("cannot create file %q, file exists at path, overwrite not specified", path) + } + } + } + + if finfo.IsDir() { + err := os.MkdirAll(path, finfo.Mode()) + if err != nil { + return 0, fmt.Errorf("cannot create directory %q: %w", path, err) + } + } else { + err := os.MkdirAll(filepath.Dir(path), 0755) + if err != nil { + return 0, fmt.Errorf("cannot create parent directory %q: %w", filepath.Dir(path), err) + } + } + + file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, finfo.Mode()) + if err != nil { + return 0, fmt.Errorf("cannot create new file %q: %w", path, err) + } + defer file.Close() + _, err = io.Copy(file, srcFile) + if err != nil { + return 0, fmt.Errorf("cannot write file %q: %w", path, err) + } + + return finfo.Size(), nil + } + if srcConn.Host == destConn.Host { srcPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(srcConn.Path)) - err := os.Rename(srcPathCleaned, destPathCleaned) + + srcFileStat, err := os.Stat(srcPathCleaned) if err != nil { - return fmt.Errorf("cannot copy file %q to %q: %w", srcPathCleaned, destPathCleaned, err) + return fmt.Errorf("cannot stat file %q: %w", srcPathCleaned, err) + } + + if srcFileStat.IsDir() { + err = filepath.Walk(srcPathCleaned, func(path string, info fs.FileInfo, err error) error { + if err != nil { + return err + } + srcFilePath := path + destFilePath := filepath.Join(destPathCleaned, strings.TrimPrefix(path, srcPathCleaned)) + var file *os.File + if !info.IsDir() { + file, err = os.Open(srcFilePath) + if err != nil { + return fmt.Errorf("cannot open file %q: %w", srcFilePath, err) + } + defer file.Close() + } + _, err = copyFileFunc(destFilePath, info, file) + return err + }) + if err != nil { + return fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err) + } + } else { + file, err := os.Open(srcPathCleaned) + if err != nil { + return fmt.Errorf("cannot open file %q: %w", srcPathCleaned, err) + } + defer file.Close() + _, err = copyFileFunc(destPathCleaned, srcFileStat, file) + if err != nil { + return fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err) + } } } else { timeout := DefaultTimeout @@ -376,70 +478,11 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C } numFiles++ finfo := next.FileInfo() - nextPath := filepath.Join(destPathCleaned, next.Name) - destinfo, err = os.Stat(nextPath) - if err != nil && !errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("cannot stat file %q: %w", nextPath, err) - } - if !finfo.IsDir() { - totalBytes += finfo.Size() - } - - if destinfo != nil { - if destinfo.IsDir() { - if !finfo.IsDir() { - if !overwrite { - return fmt.Errorf("cannot create directory %q, file exists at path, overwrite not specified", nextPath) - } else { - err := os.Remove(nextPath) - if err != nil { - return fmt.Errorf("cannot remove file %q: %w", nextPath, err) - } - } - } else if !merge && !overwrite { - return fmt.Errorf("cannot create directory %q, directory exists at path, neither overwrite nor merge specified", nextPath) - } else if overwrite { - err := os.RemoveAll(nextPath) - if err != nil { - return fmt.Errorf("cannot remove directory %q: %w", nextPath, err) - } - } - } else { - if finfo.IsDir() { - if !overwrite { - return fmt.Errorf("cannot create file %q, directory exists at path, overwrite not specified", nextPath) - } else { - err := os.RemoveAll(nextPath) - if err != nil { - return fmt.Errorf("cannot remove directory %q: %w", nextPath, err) - } - } - } else if !overwrite { - return fmt.Errorf("cannot create file %q, file exists at path, overwrite not specified", nextPath) - } - } - } else { - if finfo.IsDir() { - err := os.MkdirAll(nextPath, finfo.Mode()) - if err != nil { - return fmt.Errorf("cannot create directory %q: %w", nextPath, err) - } - } else { - err := os.MkdirAll(filepath.Dir(nextPath), 0755) - if err != nil { - return fmt.Errorf("cannot create parent directory %q: %w", filepath.Dir(nextPath), err) - } - file, err := os.OpenFile(nextPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, finfo.Mode()) - if err != nil { - return fmt.Errorf("cannot create new file %q: %w", nextPath, err) - } - _, err = io.Copy(file, reader) - if err != nil { - return fmt.Errorf("cannot write file %q: %w", nextPath, err) - } - file.Close() - } + n, err := copyFileFunc(filepath.Join(destPathCleaned, next.Name), finfo, reader) + if err != nil { + return fmt.Errorf("cannot copy file %q: %w", next.Name, err) } + totalBytes += n return nil }) if err != nil {