diff --git a/fastdialer/utils/dialwrap.go b/fastdialer/utils/dialwrap.go index f81a7c8..1993551 100644 --- a/fastdialer/utils/dialwrap.go +++ b/fastdialer/utils/dialwrap.go @@ -153,7 +153,7 @@ func (d *DialWrap) DialContext(ctx context.Context, _ string, _ string) (net.Con err := d.err d.firstConnCond.L.Unlock() - if err != nil && !errkit.Is(err, ErrInflightCancel) && !errkit.Is(err, context.Canceled) { + if err != nil && !errkit.Is(err, ErrInflightCancel) && !errkit.Is(err, context.Canceled) && !errkit.Is(err, context.DeadlineExceeded) { return nil, err } return d.dial(ctx) @@ -274,7 +274,7 @@ func (d *DialWrap) dialAllParallel(ctx context.Context) ([]*dialResult, error) { return conns, nil } else { - if !errkit.Is(result.error, ErrInflightCancel) && !errkit.Is(result.error, context.Canceled) { + if !errkit.Is(result.error, ErrInflightCancel) && !errkit.Is(result.error, context.Canceled) && !errkit.Is(result.error, context.DeadlineExceeded) { errs = append(errs, result) } } @@ -293,7 +293,7 @@ func (d *DialWrap) dialAllParallel(ctx context.Context) ([]*dialResult, error) { } // If not inflight cancel then it is a permanent error (port closed/filtered) - if !errkit.Is(finalErr, ErrInflightCancel) { + if !errkit.Is(finalErr, ErrInflightCancel) && !errkit.Is(finalErr, context.Canceled) && !errkit.Is(finalErr, context.DeadlineExceeded) { return nil, errkit.Append(ErrPortClosedOrFiltered, finalErr) } diff --git a/fastdialer/utils/dialwrap_test.go b/fastdialer/utils/dialwrap_test.go new file mode 100644 index 0000000..f19e6cc --- /dev/null +++ b/fastdialer/utils/dialwrap_test.go @@ -0,0 +1,174 @@ +package utils + +import ( + "context" + "fmt" + "net" + "sync/atomic" + "testing" + "time" + + "github.com/projectdiscovery/utils/errkit" +) + +// TestDialAllParallel_DeadlineExceeded verifies that when all dials fail with +// context.DeadlineExceeded the error is NOT wrapped as ErrPortClosedOrFiltered. +// This prevents poisoning the dial cache with a permanent error when the real +// cause was a transient timeout. +func TestDialAllParallel_DeadlineExceeded(t *testing.T) { + t.Parallel() + + // Use a very short deadline so all dials fail with DeadlineExceeded. + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + + // Let the deadline expire before dialing. + time.Sleep(5 * time.Millisecond) + + dw, err := NewDialWrap( + &net.Dialer{Timeout: 1 * time.Millisecond}, + []string{"192.0.2.1"}, // RFC 5737 TEST-NET, will never connect + "tcp", + "192.0.2.1:12345", + "12345", + ) + if err != nil { + t.Fatal(err) + } + + _, dialErr := dw.dialAllParallel(ctx) + if dialErr == nil { + t.Fatal("expected an error from dialAllParallel with expired context, got nil") + } + + if errkit.Is(dialErr, ErrPortClosedOrFiltered) { + t.Fatalf("DeadlineExceeded must not be classified as ErrPortClosedOrFiltered, got: %v", dialErr) + } +} + +// TestDialAllParallel_ContextCanceled verifies that when all dials fail with +// context.Canceled the error is NOT wrapped as ErrPortClosedOrFiltered. +func TestDialAllParallel_ContextCanceled(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + // Cancel immediately so all dials get context.Canceled. + cancel() + + dw, err := NewDialWrap( + &net.Dialer{Timeout: 1 * time.Millisecond}, + []string{"192.0.2.1"}, + "tcp", + "192.0.2.1:12345", + "12345", + ) + if err != nil { + t.Fatal(err) + } + + _, dialErr := dw.dialAllParallel(ctx) + if dialErr == nil { + t.Fatal("expected an error from dialAllParallel with canceled context, got nil") + } + + if errkit.Is(dialErr, ErrPortClosedOrFiltered) { + t.Fatalf("context.Canceled must not be classified as ErrPortClosedOrFiltered, got: %v", dialErr) + } +} + +// TestDialAllParallel_RealConnectionRefused verifies that a genuine +// connection-refused error IS still classified as ErrPortClosedOrFiltered. +func TestDialAllParallel_RealConnectionRefused(t *testing.T) { + t.Parallel() + + // Bind a listener and immediately close it to guarantee a refused port. + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + port := ln.Addr().(*net.TCPAddr).Port + ln.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + dw, err := NewDialWrap( + &net.Dialer{Timeout: 2 * time.Second}, + []string{"127.0.0.1"}, + "tcp", + "127.0.0.1", + fmt.Sprintf("%d", port), + ) + if err != nil { + t.Fatal(err) + } + + _, dialErr := dw.dialAllParallel(ctx) + if dialErr == nil { + t.Fatal("expected an error from dialAllParallel to a refused port, got nil") + } + + if !errkit.Is(dialErr, ErrPortClosedOrFiltered) { + t.Fatalf("connection refused should still be classified as ErrPortClosedOrFiltered, got: %v", dialErr) + } +} + +// TestDialContext_DeadlineExceededNotCached verifies that DialContext with a +// deadline-exceeded first connection does not permanently store the error +// as ErrPortClosedOrFiltered, so a subsequent caller is not poisoned. +func TestDialContext_DeadlineExceededNotCached(t *testing.T) { + t.Parallel() + + // Start a listener that accepts but never responds (simulates slow host). + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + port := ln.Addr().(*net.TCPAddr).Port + + var accepted atomic.Int32 + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + accepted.Add(1) + go func(c net.Conn) { + time.Sleep(10 * time.Second) + c.Close() + }(conn) + } + }() + + dw, err := NewDialWrap( + &net.Dialer{Timeout: 50 * time.Millisecond}, + []string{"127.0.0.1"}, + "tcp", + "127.0.0.1", + fmt.Sprintf("%d", port), + ) + if err != nil { + t.Fatal(err) + } + + // First call: use an already-expired context. + expiredCtx, cancel1 := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel1() + time.Sleep(5 * time.Millisecond) + + _, err1 := dw.DialContext(expiredCtx, "", "") + if err1 == nil { + t.Fatal("expected error from DialContext with expired context") + } + + // The stored error must NOT be ErrPortClosedOrFiltered. + dw.firstConnCond.L.Lock() + storedErr := dw.err + dw.firstConnCond.L.Unlock() + + if storedErr != nil && errkit.Is(storedErr, ErrPortClosedOrFiltered) { + t.Fatalf("deadline-exceeded error must not be cached as ErrPortClosedOrFiltered: %v", storedErr) + } +} \ No newline at end of file