Skip to content

Fix WSH copy internal #1906

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 6, 2025
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 108 additions & 65 deletions pkg/wshrpc/wshremote/wshremote.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,11 +348,113 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C
if err != nil {
return fmt.Errorf("cannot parse source URI %q: %w", srcUri, err)
}

copyFileFunc := func(path string, finfo fs.FileInfo, srcFile io.Reader) (int64, error) {
destinfo, err = os.Stat(path)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return 0, fmt.Errorf("cannot stat file %q: %w", path, err)
}

if destinfo != nil {
if destinfo.IsDir() {
if !finfo.IsDir() {
// try to create file in directory
path = filepath.Join(path, filepath.Base(finfo.Name()))
newdestinfo, err := os.Stat(path)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return 0, fmt.Errorf("cannot stat file %q: %w", path, err)
}
if newdestinfo != nil && !overwrite {
return 0, fmt.Errorf("cannot create file %q, file exists at path, overwrite not specified", path)
}
} else if !merge && !overwrite {
return 0, fmt.Errorf("cannot create directory %q, directory exists at path, neither overwrite nor merge specified", path)
} else if overwrite {
err := os.RemoveAll(path)
if err != nil {
return 0, fmt.Errorf("cannot remove directory %q: %w", path, err)
}
}
} else {
if finfo.IsDir() {
if !overwrite {
return 0, fmt.Errorf("cannot create file %q, directory exists at path, overwrite not specified", path)
} else {
err := os.RemoveAll(path)
if err != nil {
return 0, fmt.Errorf("cannot remove directory %q: %w", path, err)
}
}
} else if !overwrite {
return 0, fmt.Errorf("cannot create file %q, file exists at path, overwrite not specified", path)
}
}
}

if finfo.IsDir() {
err := os.MkdirAll(path, finfo.Mode())
if err != nil {
return 0, fmt.Errorf("cannot create directory %q: %w", path, err)
}
} else {
err := os.MkdirAll(filepath.Dir(path), 0755)
if err != nil {
return 0, fmt.Errorf("cannot create parent directory %q: %w", filepath.Dir(path), err)
}
}

file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, finfo.Mode())
if err != nil {
return 0, fmt.Errorf("cannot create new file %q: %w", path, err)
}
_, err = io.Copy(file, srcFile)
if err != nil {
return 0, fmt.Errorf("cannot write file %q: %w", path, err)
}
file.Close()

return finfo.Size(), nil
}

if srcConn.Host == destConn.Host {
srcPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(srcConn.Path))
err := os.Rename(srcPathCleaned, destPathCleaned)

srcFileStat, err := os.Stat(srcPathCleaned)
if err != nil {
return fmt.Errorf("cannot copy file %q to %q: %w", srcPathCleaned, destPathCleaned, err)
return fmt.Errorf("cannot stat file %q: %w", srcPathCleaned, err)
}

if srcFileStat.IsDir() {
err = filepath.Walk(srcPathCleaned, func(path string, info fs.FileInfo, err error) error {
if err != nil {
return err
}
path = filepath.Join(destPathCleaned, strings.TrimPrefix(path, srcPathCleaned))

var file *os.File
if !info.IsDir() {
file, err = os.Open(path)
if err != nil {
return fmt.Errorf("cannot open file %q: %w", path, err)
}
defer file.Close()
}
_, err = copyFileFunc(path, info, file)
return err
})
if err != nil {
return fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err)
}
} else {
file, err := os.Open(srcPathCleaned)
defer file.Close()
if err != nil {
return fmt.Errorf("cannot open file %q: %w", srcPathCleaned, err)
}
_, err = copyFileFunc(destPathCleaned, srcFileStat, file)
if err != nil {
return fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err)
}
}
} else {
timeout := DefaultTimeout
Expand All @@ -376,70 +478,11 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C
}
numFiles++
finfo := next.FileInfo()
nextPath := filepath.Join(destPathCleaned, next.Name)
destinfo, err = os.Stat(nextPath)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("cannot stat file %q: %w", nextPath, err)
}
if !finfo.IsDir() {
totalBytes += finfo.Size()
}

if destinfo != nil {
if destinfo.IsDir() {
if !finfo.IsDir() {
if !overwrite {
return fmt.Errorf("cannot create directory %q, file exists at path, overwrite not specified", nextPath)
} else {
err := os.Remove(nextPath)
if err != nil {
return fmt.Errorf("cannot remove file %q: %w", nextPath, err)
}
}
} else if !merge && !overwrite {
return fmt.Errorf("cannot create directory %q, directory exists at path, neither overwrite nor merge specified", nextPath)
} else if overwrite {
err := os.RemoveAll(nextPath)
if err != nil {
return fmt.Errorf("cannot remove directory %q: %w", nextPath, err)
}
}
} else {
if finfo.IsDir() {
if !overwrite {
return fmt.Errorf("cannot create file %q, directory exists at path, overwrite not specified", nextPath)
} else {
err := os.RemoveAll(nextPath)
if err != nil {
return fmt.Errorf("cannot remove directory %q: %w", nextPath, err)
}
}
} else if !overwrite {
return fmt.Errorf("cannot create file %q, file exists at path, overwrite not specified", nextPath)
}
}
} else {
if finfo.IsDir() {
err := os.MkdirAll(nextPath, finfo.Mode())
if err != nil {
return fmt.Errorf("cannot create directory %q: %w", nextPath, err)
}
} else {
err := os.MkdirAll(filepath.Dir(nextPath), 0755)
if err != nil {
return fmt.Errorf("cannot create parent directory %q: %w", filepath.Dir(nextPath), err)
}
file, err := os.OpenFile(nextPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, finfo.Mode())
if err != nil {
return fmt.Errorf("cannot create new file %q: %w", nextPath, err)
}
_, err = io.Copy(file, reader)
if err != nil {
return fmt.Errorf("cannot write file %q: %w", nextPath, err)
}
file.Close()
}
n, err := copyFileFunc(filepath.Join(destPathCleaned, next.Name), finfo, reader)
if err != nil {
return fmt.Errorf("cannot copy file %q: %w", next.Name, err)
}
totalBytes += n
return nil
})
if err != nil {
Expand Down
Loading