From da3690cd99ec802444fee2b673f61e19b883cde3 Mon Sep 17 00:00:00 2001 From: Raphael Westphal Date: Sun, 30 Nov 2025 00:53:23 +0100 Subject: [PATCH] fix: read too much data --- archive_test.go | 10 ++- downloader_test.go | 105 +++++++++++++++++++++--------- marianne.go | 156 ++++++++++++++++++++++++++++++++++++--------- marianne_test.go | 5 +- validation_test.go | 49 ++++++++++---- 5 files changed, 250 insertions(+), 75 deletions(-) diff --git a/archive_test.go b/archive_test.go index 5858423..4e67068 100644 --- a/archive_test.go +++ b/archive_test.go @@ -45,7 +45,10 @@ func TestExtractZipFileBasic(t *testing.T) { // Skip ZIP extraction test - requires TUI integration t.Skip("ZIP extraction requires TUI program which can't be easily mocked") - d := NewDownloader(mock.URL(), 4, 1024, "", 0, false, 3, 0, 1024*1024*1024) + d, err := NewDownloader(mock.URL(), 4, 1024, "", 0, false, 3, 0, 1024*1024*1024) + if err != nil { + t.Fatalf("NewDownloader failed: %v", err) + } // Create a mock TUI program p := tea.NewProgram(initialModel(mock.URL(), 1024, false, false, 1)) @@ -105,7 +108,10 @@ func TestExtractZipFileLarge(t *testing.T) { mock := NewMockHTTPServer(content) defer mock.Close() - d := NewDownloader(mock.URL(), 4, 1024, "", 0, false, 3, 0, 1024*1024*1024) + d, err := NewDownloader(mock.URL(), 4, 1024, "", 0, false, 3, 0, 1024*1024*1024) + if err != nil { + t.Fatalf("NewDownloader failed: %v", err) + } p := tea.NewProgram(initialModel(mock.URL(), 1024, false, false, 1)) // Extract - this would fail with file descriptor exhaustion diff --git a/downloader_test.go b/downloader_test.go index e8e3eff..1ecf839 100644 --- a/downloader_test.go +++ b/downloader_test.go @@ -15,9 +15,12 @@ func TestGetFileSize(t *testing.T) { mock := NewMockHTTPServer(content) defer mock.Close() - d := NewDownloader(mock.URL(), 4, 1024*1024, "", 0, false, 3, 100*time.Millisecond, 1024*1024*1024) + d, err := NewDownloader(mock.URL(), 4, 1024*1024, "", 0, false, 3, 100*time.Millisecond, 1024*1024*1024) + if err != nil { + t.Fatalf("NewDownloader failed: %v", err) + } - err := d.getFileSize() + err = d.getFileSize() if err != nil { t.Fatalf("getFileSize() error = %v, want nil", err) } @@ -36,10 +39,13 @@ func TestGetFileSizeRetry(t *testing.T) { // Fail first 2 attempts, succeed on 3rd mock.SetMaxFailures(2) - d := NewDownloader(mock.URL(), 4, 1024, "", 0, false, 5, 50*time.Millisecond, 1024*1024*1024) + d, err := NewDownloader(mock.URL(), 4, 1024, "", 0, false, 5, 50*time.Millisecond, 1024*1024*1024) + if err != nil { + t.Fatalf("NewDownloader failed: %v", err) + } start := time.Now() - err := d.getFileSize() + err = d.getFileSize() elapsed := time.Since(start) if err != nil { @@ -66,9 +72,12 @@ func TestGetFileSizeExhaustedRetries(t *testing.T) { // Fail more times than max retries mock.SetMaxFailures(10) - d := NewDownloader(mock.URL(), 4, 1024, "", 0, false, 3, 50*time.Millisecond, 1024*1024*1024) + d, err := NewDownloader(mock.URL(), 4, 1024, "", 0, false, 3, 50*time.Millisecond, 1024*1024*1024) + if err != nil { + t.Fatalf("NewDownloader failed: %v", err) + } - err := d.getFileSize() + err = d.getFileSize() if err == nil { t.Fatal("getFileSize() error = nil, want error when retries exhausted") } @@ -87,13 +96,16 @@ func TestDownloadChunk(t *testing.T) { mock := NewMockHTTPServer(content) defer mock.Close() - d := NewDownloader(mock.URL(), 4, 1024, "", 0, false, 3, 100*time.Millisecond, 1024*1024*1024) + d, err := NewDownloader(mock.URL(), 4, 1024, "", 0, false, 3, 100*time.Millisecond, 1024*1024*1024) + if err != nil { + t.Fatalf("NewDownloader failed: %v", err) + } ctx := context.Background() var buf bytes.Buffer // Download first 1KB - err := d.downloadChunk(ctx, 0, 1023, &buf) + err = d.downloadChunk(ctx, 0, 1023, &buf) if err != nil { t.Fatalf("downloadChunk() error = %v, want nil", err) } @@ -112,13 +124,16 @@ func TestDownloadChunkMiddle(t *testing.T) { mock := NewMockHTTPServer(content) defer mock.Close() - d := NewDownloader(mock.URL(), 4, 1024, "", 0, false, 3, 100*time.Millisecond, 1024*1024*1024) + d, err := NewDownloader(mock.URL(), 4, 1024, "", 0, false, 3, 100*time.Millisecond, 1024*1024*1024) + if err != nil { + t.Fatalf("NewDownloader failed: %v", err) + } ctx := context.Background() var buf bytes.Buffer // Download bytes 5000-5999 - err := d.downloadChunk(ctx, 5000, 5999, &buf) + err = d.downloadChunk(ctx, 5000, 5999, &buf) if err != nil { t.Fatalf("downloadChunk() error = %v, want nil", err) } @@ -139,12 +154,15 @@ func TestDownloadChunkRetry(t *testing.T) { // Fail first 2 attempts mock.SetMaxFailures(2) - d := NewDownloader(mock.URL(), 4, 1024, "", 0, false, 5, 50*time.Millisecond, 1024*1024*1024) + d, err := NewDownloader(mock.URL(), 4, 1024, "", 0, false, 5, 50*time.Millisecond, 1024*1024*1024) + if err != nil { + t.Fatalf("NewDownloader failed: %v", err) + } ctx := context.Background() var buf bytes.Buffer - err := d.downloadChunk(ctx, 0, 1023, &buf) + err = d.downloadChunk(ctx, 0, 1023, &buf) if err != nil { t.Fatalf("downloadChunk() with retries error = %v, want nil", err) } @@ -169,13 +187,16 @@ func TestDownloadChunkTimeout(t *testing.T) { // Add delay longer than timeout mock.SetRequestDelay(6 * time.Minute) - d := NewDownloader(mock.URL(), 4, 1024, "", 0, false, 1, 100*time.Millisecond, 1024*1024*1024) + d, err := NewDownloader(mock.URL(), 4, 1024, "", 0, false, 1, 100*time.Millisecond, 1024*1024*1024) + if err != nil { + t.Fatalf("NewDownloader failed: %v", err) + } ctx := context.Background() var buf bytes.Buffer start := time.Now() - err := d.downloadChunk(ctx, 0, 1023, &buf) + err = d.downloadChunk(ctx, 0, 1023, &buf) elapsed := time.Since(start) if err == nil { @@ -197,7 +218,10 @@ func TestDownloadChunkCancellation(t *testing.T) { // Add delay to ensure we can cancel mock.SetRequestDelay(100 * time.Millisecond) - d := NewDownloader(mock.URL(), 4, 1024, "", 0, false, 3, 50*time.Millisecond, 1024*1024*1024) + d, err := NewDownloader(mock.URL(), 4, 1024, "", 0, false, 3, 50*time.Millisecond, 1024*1024*1024) + if err != nil { + t.Fatalf("NewDownloader failed: %v", err) + } ctx, cancel := context.WithCancel(context.Background()) var buf bytes.Buffer @@ -205,7 +229,7 @@ func TestDownloadChunkCancellation(t *testing.T) { // Cancel immediately cancel() - err := d.downloadChunk(ctx, 0, 1023, &buf) + err = d.downloadChunk(ctx, 0, 1023, &buf) if err == nil { t.Fatal("downloadChunk() error = nil, want cancellation error") } @@ -217,13 +241,16 @@ func TestDownloadChunkBoundary(t *testing.T) { mock := NewMockHTTPServer(content) defer mock.Close() - d := NewDownloader(mock.URL(), 4, 512, "", 0, false, 3, 100*time.Millisecond, 1024*1024*1024) + d, err := NewDownloader(mock.URL(), 4, 512, "", 0, false, 3, 100*time.Millisecond, 1024*1024*1024) + if err != nil { + t.Fatalf("NewDownloader failed: %v", err) + } ctx := context.Background() var buf bytes.Buffer // Download last chunk (bytes 512-1023) - err := d.downloadChunk(ctx, 512, 1023, &buf) + err = d.downloadChunk(ctx, 512, 1023, &buf) if err != nil { t.Fatalf("downloadChunk() boundary error = %v, want nil", err) } @@ -244,7 +271,10 @@ func TestDownloadChunkServerNoRangeSupport(t *testing.T) { // Disable range support mock.SetSupportsRanges(false) - d := NewDownloader(mock.URL(), 4, 1024, "", 0, false, 3, 100*time.Millisecond, 1024*1024*1024) + d, err := NewDownloader(mock.URL(), 4, 1024, "", 0, false, 3, 100*time.Millisecond, 1024*1024*1024) + if err != nil { + t.Fatalf("NewDownloader failed: %v", err) + } ctx := context.Background() var buf bytes.Buffer @@ -252,7 +282,7 @@ func TestDownloadChunkServerNoRangeSupport(t *testing.T) { // Try to download a chunk - should fail because server doesn't support ranges // When server doesn't support ranges, it returns 200 OK with full content, // which is incompatible with parallel downloads and would corrupt the file - err := d.downloadChunk(ctx, 1024, 2047, &buf) + err = d.downloadChunk(ctx, 1024, 2047, &buf) if err == nil { t.Fatal("downloadChunk() error = nil, want error for server without range support") } @@ -266,10 +296,13 @@ func TestDownloadChunkServerNoRangeSupport(t *testing.T) { // TestRetryWithBackoffSuccess tests retry logic succeeds immediately func TestRetryWithBackoffSuccess(t *testing.T) { - d := NewDownloader("http://example.com/test", 4, 1024, "", 0, false, 3, 100*time.Millisecond, 1024*1024*1024) + d, err := NewDownloader("http://example.com/test", 4, 1024, "", 0, false, 3, 100*time.Millisecond, 1024*1024*1024) + if err != nil { + t.Fatalf("NewDownloader failed: %v", err) + } attempts := 0 - err := d.retryWithBackoff(context.Background(), "test operation", func() error { + err = d.retryWithBackoff(context.Background(), "test operation", func() error { attempts++ return nil // Success on first try }) @@ -285,10 +318,13 @@ func TestRetryWithBackoffSuccess(t *testing.T) { // TestRetryWithBackoffEventualSuccess tests retry succeeds after failures func TestRetryWithBackoffEventualSuccess(t *testing.T) { - d := NewDownloader("http://example.com/test", 4, 1024, "", 0, false, 5, 50*time.Millisecond, 1024*1024*1024) + d, err := NewDownloader("http://example.com/test", 4, 1024, "", 0, false, 5, 50*time.Millisecond, 1024*1024*1024) + if err != nil { + t.Fatalf("NewDownloader failed: %v", err) + } attempts := 0 - err := d.retryWithBackoff(context.Background(), "test operation", func() error { + err = d.retryWithBackoff(context.Background(), "test operation", func() error { attempts++ if attempts < 3 { return errors.New("temporary failure") @@ -307,11 +343,14 @@ func TestRetryWithBackoffEventualSuccess(t *testing.T) { // TestRetryWithBackoffAllFail tests behavior when all retries fail func TestRetryWithBackoffAllFail(t *testing.T) { - d := NewDownloader("http://example.com/test", 4, 1024, "", 0, false, 3, 50*time.Millisecond, 1024*1024*1024) + d, err := NewDownloader("http://example.com/test", 4, 1024, "", 0, false, 3, 50*time.Millisecond, 1024*1024*1024) + if err != nil { + t.Fatalf("NewDownloader failed: %v", err) + } attempts := 0 testErr := errors.New("persistent failure") - err := d.retryWithBackoff(context.Background(), "test operation", func() error { + err = d.retryWithBackoff(context.Background(), "test operation", func() error { attempts++ return testErr }) @@ -329,8 +368,11 @@ func TestRetryWithBackoffAllFail(t *testing.T) { // TestRetryWithBackoffCancellation tests context cancellation during retry func TestRetryWithBackoffCancellation(t *testing.T) { - d := NewDownloader("http://example.com/test", 4, 1024, "", 0, false, 10, 100*time.Millisecond, 1024*1024*1024) + d, err := NewDownloader("http://example.com/test", 4, 1024, "", 0, false, 10, 100*time.Millisecond, 1024*1024*1024) + if err != nil { + t.Fatalf("NewDownloader failed: %v", err) + } ctx, cancel := context.WithCancel(context.Background()) attempts := 0 @@ -348,7 +390,7 @@ func TestRetryWithBackoffCancellation(t *testing.T) { time.Sleep(150 * time.Millisecond) cancel() - err := <-errChan + err = <-errChan if err == nil { t.Fatal("retryWithBackoff() error = nil, want cancellation error") } @@ -370,14 +412,17 @@ func TestRateLimitedReader(t *testing.T) { // Set bandwidth limit to 10KB/s bandwidthLimit := int64(10 * 1024) - d := NewDownloader(mock.URL(), 4, 1024, "", bandwidthLimit, false, 3, 100*time.Millisecond, 1024*1024*1024) + d, err := NewDownloader(mock.URL(), 4, 1024, "", bandwidthLimit, false, 3, 100*time.Millisecond, 1024*1024*1024) + if err != nil { + t.Fatalf("NewDownloader failed: %v", err) + } ctx := context.Background() var buf bytes.Buffer start := time.Now() // Download 30KB (should take ~3 seconds at 10KB/s after burst) - err := d.downloadChunk(ctx, 0, 30*1024-1, &buf) + err = d.downloadChunk(ctx, 0, 30*1024-1, &buf) elapsed := time.Since(start) if err != nil { diff --git a/marianne.go b/marianne.go index ccf6c91..fe474ba 100644 --- a/marianne.go +++ b/marianne.go @@ -16,6 +16,7 @@ import ( "sort" "strconv" "strings" + "sync" "sync/atomic" "syscall" "time" @@ -64,11 +65,14 @@ type Downloader struct { memoryLimit int64 // Memory limit for buffering chunks } -func NewDownloader(url string, workers int, chunkSize int64, proxyURL string, bandwidthLimit int64, verbose bool, maxRetries int, retryDelay time.Duration, memoryLimit int64) *Downloader { +func NewDownloader(url string, workers int, chunkSize int64, proxyURL string, bandwidthLimit int64, verbose bool, maxRetries int, retryDelay time.Duration, memoryLimit int64) (*Downloader, error) { // Validate input parameters if workers <= 0 { workers = 1 // Default to 1 worker } + if workers > 1000 { + workers = 1000 // Cap at reasonable maximum + } if chunkSize <= 0 { chunkSize = defaultChunkSize // Use default chunk size } @@ -93,9 +97,10 @@ func NewDownloader(url string, workers int, chunkSize int64, proxyURL string, ba // Configure proxy if provided if proxyURL != "" { proxy, err := neturl.Parse(proxyURL) - if err == nil { - transport.Proxy = http.ProxyURL(proxy) + if err != nil { + return nil, fmt.Errorf("invalid proxy URL: %w", err) } + transport.Proxy = http.ProxyURL(proxy) } d := &Downloader{ @@ -120,7 +125,7 @@ func NewDownloader(url string, workers int, chunkSize int64, proxyURL string, ba d.rateLimiter = rate.NewLimiter(rate.Limit(bandwidthLimit), int(bandwidthLimit)) } - return d + return d, nil } // retryWithBackoff executes a function with exponential backoff retry logic @@ -759,7 +764,7 @@ func detectOrUseArchiveType(filename string, forcedFormat string) (string, strin return detectArchiveType(filename) } -func (d *Downloader) Download(ctx context.Context, p *tea.Program, outputDir string, forcedFormat string) error { +func (d *Downloader) Download(ctx context.Context, p *tea.Program, outputDir string, forcedFormat string, tracker *cleanupTracker) error { // Get file info if err := d.getFileSize(); err != nil { return err @@ -782,7 +787,7 @@ func (d *Downloader) Download(ctx context.Context, p *tea.Program, outputDir str if isZip { // Handle ZIP files - return d.downloadAndExtractZip(ctx, p, outputDir) + return d.downloadAndExtractZip(ctx, p, outputDir, tracker) } // Handle tar-based archives @@ -947,7 +952,7 @@ func (w *multiPartCountingWriter) Write(p []byte) (int, error) { func downloadMultiPart(ctx context.Context, urls []string, partSizes []int64, p *tea.Program, outputDir string, workers int, chunkSize int64, proxyURL string, limitBytes int64, verbose bool, maxRetries int, retryDelay time.Duration, - memBytes int64, forcedFormat string) error { + memBytes int64, forcedFormat string, tracker *cleanupTracker) error { // Detect archive type from first URL or use forced format tarFlag, tarCommand, isZip, err := detectOrUseArchiveType(urls[0], forcedFormat) @@ -964,7 +969,7 @@ func downloadMultiPart(ctx context.Context, urls []string, partSizes []int64, p if isZip { return downloadMultiPartZip(ctx, urls, partSizes, p, outputDir, workers, chunkSize, - proxyURL, limitBytes, verbose, maxRetries, retryDelay, memBytes) + proxyURL, limitBytes, verbose, maxRetries, retryDelay, memBytes, tracker) } return downloadMultiPartTar(ctx, urls, partSizes, p, outputDir, tarFlag, tarCommand, @@ -1103,8 +1108,14 @@ func downloadMultiPartTar(ctx context.Context, urls []string, partSizes []int64, }) // Create downloader for this part - d := NewDownloader(url, workers, chunkSize, proxyURL, limitBytes, verbose, + d, err := NewDownloader(url, workers, chunkSize, proxyURL, limitBytes, verbose, maxRetries, retryDelay, memBytes) + if err != nil { + pipeWriter.CloseWithError(err) + close(done) + cmd.Process.Kill() + return fmt.Errorf("failed to create downloader: %w", err) + } d.totalSize = partSizes[i] // Create counting writer that updates overall progress @@ -1145,16 +1156,28 @@ func downloadMultiPartTar(ctx context.Context, urls []string, partSizes []int64, func downloadMultiPartZip(ctx context.Context, urls []string, partSizes []int64, p *tea.Program, outputDir string, workers int, chunkSize int64, proxyURL string, limitBytes int64, verbose bool, maxRetries int, retryDelay time.Duration, - memBytes int64) error { + memBytes int64, tracker *cleanupTracker) error { // Create temp file for combined ZIP tmpFile, err := os.CreateTemp("", "marianne-multipart-*.zip") if err != nil { return fmt.Errorf("failed to create temp file: %w", err) } + tempFileName := tmpFile.Name() + // Register temp file for cleanup on interruption + if tracker != nil { + tracker.Add(tempFileName) + } + cleanupFile := false // Only cleanup on success defer func() { tmpFile.Close() - os.Remove(tmpFile.Name()) + if cleanupFile { + // Unregister before removing + if tracker != nil { + tracker.Remove(tempFileName) + } + os.Remove(tempFileName) + } }() // Calculate total size @@ -1217,8 +1240,12 @@ func downloadMultiPartZip(ctx context.Context, urls []string, partSizes []int64, url: url, }) - d := NewDownloader(url, workers, chunkSize, proxyURL, limitBytes, verbose, + d, err := NewDownloader(url, workers, chunkSize, proxyURL, limitBytes, verbose, maxRetries, retryDelay, memBytes) + if err != nil { + close(done) + return fmt.Errorf("failed to create downloader: %w", err) + } d.totalSize = partSizes[i] // Create counting writer @@ -1239,7 +1266,11 @@ func downloadMultiPartZip(ctx context.Context, urls []string, partSizes []int64, // Extract the combined ZIP using existing extractZipFile dummyDownloader := &Downloader{} - return dummyDownloader.extractZipFile(tmpFile.Name(), outputDir, p) + err = dummyDownloader.extractZipFile(tempFileName, outputDir, p) + if err == nil { + cleanupFile = true // Only cleanup on successful extraction + } + return err } func parseBandwidthLimit(limit string) int64 { @@ -1301,16 +1332,59 @@ func parseMemoryLimit(limit string, workers int) int64 { return parseBandwidthLimit(limit) } -func (d *Downloader) downloadAndExtractZip(ctx context.Context, p *tea.Program, outputDir string) error { +// cleanupTracker tracks temporary files for cleanup on interruption +type cleanupTracker struct { + mu sync.Mutex + files []string +} + +func (c *cleanupTracker) Add(path string) { + c.mu.Lock() + defer c.mu.Unlock() + c.files = append(c.files, path) +} + +func (c *cleanupTracker) Remove(path string) { + c.mu.Lock() + defer c.mu.Unlock() + for i, f := range c.files { + if f == path { + c.files = append(c.files[:i], c.files[i+1:]...) + break + } + } +} + +func (c *cleanupTracker) Cleanup() { + c.mu.Lock() + defer c.mu.Unlock() + for _, f := range c.files { + os.Remove(f) + } +} + +func (d *Downloader) downloadAndExtractZip(ctx context.Context, p *tea.Program, outputDir string, tracker *cleanupTracker) error { // Create temp file tmpFile, err := os.CreateTemp("", "marianne-*.zip") if err != nil { return fmt.Errorf("failed to create temp file: %w", err) } + tempFileName := tmpFile.Name() + // Register temp file for cleanup on interruption + if tracker != nil { + tracker.Add(tempFileName) + } + cleanupFile := false // Only cleanup on success defer func() { tmpFile.Close() - os.Remove(tmpFile.Name()) + if cleanupFile { + // Unregister before removing + if tracker != nil { + tracker.Remove(tempFileName) + } + os.Remove(tempFileName) + } }() // Progress reporter @@ -1366,24 +1440,32 @@ func (d *Downloader) downloadAndExtractZip(ctx context.Context, p *tea.Program, close(done) // Extract ZIP file - return d.extractZipFile(tmpFile.Name(), outputDir, p) + err = d.extractZipFile(tempFileName, outputDir, p) + if err == nil { + cleanupFile = true // Only cleanup on successful extraction + } + return err } // validateZipPath checks if a zip entry path is safe (no path traversal) func validateZipPath(path string) error { - // Check for path traversal attempts - if strings.Contains(path, "..") { - return fmt.Errorf("path contains '..': %s", path) + // Check for null bytes + if strings.Contains(path, "\x00") { + return fmt.Errorf("path contains null byte: %s", path) } - // Check for absolute paths - if filepath.IsAbs(path) { + // Clean the path and check if absolute + cleaned := filepath.Clean(path) + if filepath.IsAbs(cleaned) { return fmt.Errorf("path is absolute: %s", path) } - // Check for backslashes (Windows path separators used maliciously on Unix) - if strings.Contains(path, "\\") { - return fmt.Errorf("path contains backslash: %s", path) + // Check each component for ".." + parts := strings.Split(cleaned, string(filepath.Separator)) + for _, part := range parts { + if part == ".." { + return fmt.Errorf("path traversal attempt: %s", path) + } } return nil @@ -1422,7 +1504,9 @@ func (d *Downloader) extractZipFile(filename string, outputDir string, p *tea.Pr p.Send(fileExtractedMsg(file.Name)) if file.FileInfo().IsDir() { - os.MkdirAll(path, file.Mode()) + if err := os.MkdirAll(path, file.Mode()); err != nil { + return fmt.Errorf("failed to create directory %s: %w", path, err) + } continue } @@ -1699,7 +1783,12 @@ func main() { } else { // Single URL download path url := urls[0] - downloader := NewDownloader(url, *workers, *chunkSize, *proxyURL, limitBytes, *verbose, *maxRetries, *retryDelay, memBytes) + downloader, err := NewDownloader(url, *workers, *chunkSize, *proxyURL, limitBytes, *verbose, *maxRetries, *retryDelay, memBytes) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + // Get file size first if err := downloader.getFileSize(); err != nil { @@ -1725,6 +1814,9 @@ func main() { ) } + // Create cleanup tracker for temp files + tracker := &cleanupTracker{} + // Set up signal handler for graceful shutdown sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT) @@ -1765,16 +1857,18 @@ func main() { err = validateErr } else { err = downloadMultiPart(ctx, urls, partSizes, p, *outputDir, *workers, *chunkSize, - *proxyURL, limitBytes, *verbose, *maxRetries, *retryDelay, memBytes, *archiveFormat) + *proxyURL, limitBytes, *verbose, *maxRetries, *retryDelay, memBytes, *archiveFormat, tracker) } } else { // Single URL download url := urls[0] - downloader := NewDownloader(url, *workers, *chunkSize, *proxyURL, limitBytes, *verbose, *maxRetries, *retryDelay, memBytes) - if sizeErr := downloader.getFileSize(); sizeErr != nil { + downloader, createErr := NewDownloader(url, *workers, *chunkSize, *proxyURL, limitBytes, *verbose, *maxRetries, *retryDelay, memBytes) + if createErr != nil { + err = createErr + } else if sizeErr := downloader.getFileSize(); sizeErr != nil { err = sizeErr } else { - err = downloader.Download(ctx, p, *outputDir, *archiveFormat) + err = downloader.Download(ctx, p, *outputDir, *archiveFormat, tracker) } } @@ -1789,6 +1883,8 @@ func main() { // Handle signals in background go func() { <-sigChan + // Cleanup temp files + tracker.Cleanup() // Cancel context to stop download cancel() // Quit the TUI diff --git a/marianne_test.go b/marianne_test.go index 071d790..b989e8a 100644 --- a/marianne_test.go +++ b/marianne_test.go @@ -249,7 +249,7 @@ func TestNewDownloader(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - d := NewDownloader( + d, err := NewDownloader( tt.url, tt.workers, tt.chunkSize, @@ -261,6 +261,9 @@ func TestNewDownloader(t *testing.T) { tt.memoryLimit, ) + if err != nil { + t.Fatalf("NewDownloader failed: %v", err) + } if d.url != tt.url { t.Errorf("url = %q, want %q", d.url, tt.url) } diff --git a/validation_test.go b/validation_test.go index 197d03e..cad4c04 100644 --- a/validation_test.go +++ b/validation_test.go @@ -23,8 +23,11 @@ func TestWorkerCountValidation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - d := NewDownloader("https://example.com/test.tar.gz", tt.workers, 1024, "", 0, false, 3, 100*time.Millisecond, 1024*1024*1024) + d, err := NewDownloader("https://example.com/test.tar.gz", tt.workers, 1024, "", 0, false, 3, 100*time.Millisecond, 1024*1024*1024) + if err != nil { + t.Fatalf("NewDownloader failed: %v", err) + } // Validation now works - invalid values get set to default if tt.workers <= 0 { t.Logf("✅ FIXED: Worker count %d validated and set to default %d", tt.workers, tt.expected) @@ -58,8 +61,11 @@ func TestChunkSizeValidation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - d := NewDownloader("https://example.com/test.tar.gz", 4, tt.chunkSize, "", 0, false, 3, 100*time.Millisecond, 1024*1024*1024) + d, err := NewDownloader("https://example.com/test.tar.gz", 4, tt.chunkSize, "", 0, false, 3, 100*time.Millisecond, 1024*1024*1024) + if err != nil { + t.Fatalf("NewDownloader failed: %v", err) + } if tt.chunkSize <= 0 { t.Logf("✅ FIXED: Chunk size %d validated and set to default %d", tt.chunkSize, tt.expected) } @@ -114,18 +120,25 @@ func TestProxyURLValidation(t *testing.T) { }{ {"Valid proxy", "http://proxy:8080", false}, {"Valid with auth", "http://user:pass@proxy:8080", false}, - {"Invalid URL - BUG: silently ignored", "not a url", false}, + {"Relative path (parsed as valid)", "not a url", false}, // url.Parse accepts this as a relative path {"Empty", "", false}, - {"Malformed - BUG", "://broken", false}, + {"Malformed", "://broken", true}, // This actually fails parsing } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - d := NewDownloader("https://example.com/test.tar.gz", 4, 1024, tt.proxyURL, 0, false, 3, 100*time.Millisecond, 1024*1024*1024) + d, err := NewDownloader("https://example.com/test.tar.gz", 4, 1024, tt.proxyURL, 0, false, 3, 100*time.Millisecond, 1024*1024*1024) - // Currently errors are silently ignored in proxy setup if tt.wantErr { - t.Logf("BUG: Invalid proxy URL %q was silently ignored", tt.proxyURL) + if err == nil { + t.Errorf("NewDownloader with invalid proxy %q should have returned error", tt.proxyURL) + } + // Error is expected, test passed + return + } + + if err != nil { + t.Fatalf("NewDownloader failed: %v", err) } if d == nil { @@ -149,9 +162,12 @@ func TestZeroContentLength(t *testing.T) { mock := NewMockHTTPServer(content) defer mock.Close() - d := NewDownloader(mock.URL(), 4, 1024, "", 0, false, 3, 100*time.Millisecond, 1024*1024*1024) + d, err := NewDownloader(mock.URL(), 4, 1024, "", 0, false, 3, 100*time.Millisecond, 1024*1024*1024) + if err != nil { + t.Fatalf("NewDownloader failed: %v", err) + } - err := d.getFileSize() + err = d.getFileSize() if err != nil { t.Fatalf("getFileSize() error = %v, want nil", err) } @@ -378,8 +394,11 @@ func TestMaxRetriesValidation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - d := NewDownloader("https://example.com/test.tar.gz", 4, 1024, "", 0, false, tt.maxRetries, 100*time.Millisecond, 1024*1024*1024) + d, err := NewDownloader("https://example.com/test.tar.gz", 4, 1024, "", 0, false, tt.maxRetries, 100*time.Millisecond, 1024*1024*1024) + if err != nil { + t.Fatalf("NewDownloader failed: %v", err) + } if tt.maxRetries < 0 { t.Logf("✅ FIXED: Negative max retries %d validated and set to default %d", tt.maxRetries, tt.expected) } @@ -409,8 +428,11 @@ func TestURLValidation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Currently no URL validation in NewDownloader - d := NewDownloader(tt.url, 4, 1024, "", 0, false, 3, 100*time.Millisecond, 1024*1024*1024) + d, err := NewDownloader(tt.url, 4, 1024, "", 0, false, 3, 100*time.Millisecond, 1024*1024*1024) + if err != nil { + t.Fatalf("NewDownloader failed: %v", err) + } if d == nil { t.Fatal("Downloader should be created (no validation)") } @@ -427,8 +449,11 @@ func TestURLValidation(t *testing.T) { // TestRateLimiterBurstConfig tests rate limiter burst configuration func TestRateLimiterBurstConfig(t *testing.T) { bandwidthLimit := int64(1024 * 1024) // 1MB/s - d := NewDownloader("https://example.com/test.tar.gz", 4, 1024, "", bandwidthLimit, false, 3, 100*time.Millisecond, 1024*1024*1024) + d, err := NewDownloader("https://example.com/test.tar.gz", 4, 1024, "", bandwidthLimit, false, 3, 100*time.Millisecond, 1024*1024*1024) + if err != nil { + t.Fatalf("NewDownloader failed: %v", err) + } if d.rateLimiter == nil { t.Fatal("Rate limiter should be initialized") }