From 75518f32d46433e2d5b212a0675b7d54a9e6f345 Mon Sep 17 00:00:00 2001 From: Niek den Breeje Date: Tue, 10 Feb 2026 11:37:26 +0100 Subject: [PATCH] fix: prevent context.DeadlineExceeded from poisoning dial cache as permanent error dialAllParallel misclassified context.DeadlineExceeded (and in one path, context.Canceled) as ErrPortClosedOrFiltered. This caused transient timeouts to be cached as permanent failures, blocking all subsequent connections to that host:port for the rest of the scan. Add DeadlineExceeded and Canceled guards at all three classification points in dialwrap.go so only genuine connection-refused errors are treated as permanent. --- fastdialer/utils/dialwrap.go | 6 +- fastdialer/utils/dialwrap_test.go | 174 ++++++++++++++++++++++++++++++ 2 files changed, 177 insertions(+), 3 deletions(-) create mode 100644 fastdialer/utils/dialwrap_test.go diff --git a/fastdialer/utils/dialwrap.go b/fastdialer/utils/dialwrap.go index f81a7c8..1993551 100644 --- a/fastdialer/utils/dialwrap.go +++ b/fastdialer/utils/dialwrap.go @@ -153,7 +153,7 @@ func (d *DialWrap) DialContext(ctx context.Context, _ string, _ string) (net.Con err := d.err d.firstConnCond.L.Unlock() - if err != nil && !errkit.Is(err, ErrInflightCancel) && !errkit.Is(err, context.Canceled) { + if err != nil && !errkit.Is(err, ErrInflightCancel) && !errkit.Is(err, context.Canceled) && !errkit.Is(err, context.DeadlineExceeded) { return nil, err } return d.dial(ctx) @@ -274,7 +274,7 @@ func (d *DialWrap) dialAllParallel(ctx context.Context) ([]*dialResult, error) { return conns, nil } else { - if !errkit.Is(result.error, ErrInflightCancel) && !errkit.Is(result.error, context.Canceled) { + if !errkit.Is(result.error, ErrInflightCancel) && !errkit.Is(result.error, context.Canceled) && !errkit.Is(result.error, context.DeadlineExceeded) { errs = append(errs, result) } } @@ -293,7 +293,7 @@ func (d *DialWrap) dialAllParallel(ctx context.Context) ([]*dialResult, error) { } // If not inflight cancel then it 
is a permanent error (port closed/filtered) - if !errkit.Is(finalErr, ErrInflightCancel) { + if !errkit.Is(finalErr, ErrInflightCancel) && !errkit.Is(finalErr, context.Canceled) && !errkit.Is(finalErr, context.DeadlineExceeded) { return nil, errkit.Append(ErrPortClosedOrFiltered, finalErr) }
diff --git a/fastdialer/utils/dialwrap_test.go b/fastdialer/utils/dialwrap_test.go
new file mode 100644
index 0000000..f19e6cc
--- /dev/null
+++ b/fastdialer/utils/dialwrap_test.go
@@ -0,0 +1,174 @@
+package utils
+
+import (
+	"context"
+	"fmt"
+	"net"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/projectdiscovery/utils/errkit"
+)
+
+// TestDialAllParallel_DeadlineExceeded verifies that when all dials fail with
+// context.DeadlineExceeded the error is NOT wrapped as ErrPortClosedOrFiltered.
+// This prevents poisoning the dial cache with a permanent error when the real
+// cause was a transient timeout.
+func TestDialAllParallel_DeadlineExceeded(t *testing.T) {
+	t.Parallel()
+
+	// Use a very short deadline so all dials fail with DeadlineExceeded.
+	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
+	defer cancel()
+
+	// Block until the deadline has actually fired instead of sleeping a guessed interval.
+	<-ctx.Done()
+
+	dw, err := NewDialWrap(
+		&net.Dialer{Timeout: 1 * time.Millisecond},
+		[]string{"192.0.2.1"}, // RFC 5737 TEST-NET, will never connect
+		"tcp",
+		"192.0.2.1:12345",
+		"12345",
+	)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	_, dialErr := dw.dialAllParallel(ctx)
+	if dialErr == nil {
+		t.Fatal("expected an error from dialAllParallel with expired context, got nil")
+	}
+
+	if errkit.Is(dialErr, ErrPortClosedOrFiltered) {
+		t.Fatalf("DeadlineExceeded must not be classified as ErrPortClosedOrFiltered, got: %v", dialErr)
+	}
+}
+
+// TestDialAllParallel_ContextCanceled verifies that when all dials fail with
+// context.Canceled the error is NOT wrapped as ErrPortClosedOrFiltered.
+func TestDialAllParallel_ContextCanceled(t *testing.T) {
+	t.Parallel()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	// Cancel immediately so all dials get context.Canceled.
+	cancel()
+
+	dw, err := NewDialWrap(
+		&net.Dialer{Timeout: 1 * time.Millisecond},
+		[]string{"192.0.2.1"},
+		"tcp",
+		"192.0.2.1:12345",
+		"12345",
+	)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	_, dialErr := dw.dialAllParallel(ctx)
+	if dialErr == nil {
+		t.Fatal("expected an error from dialAllParallel with canceled context, got nil")
+	}
+
+	if errkit.Is(dialErr, ErrPortClosedOrFiltered) {
+		t.Fatalf("context.Canceled must not be classified as ErrPortClosedOrFiltered, got: %v", dialErr)
+	}
+}
+
+// TestDialAllParallel_RealConnectionRefused verifies that a genuine
+// connection-refused error IS still classified as ErrPortClosedOrFiltered.
+func TestDialAllParallel_RealConnectionRefused(t *testing.T) {
+	t.Parallel()
+
+	// Bind a listener and immediately close it to guarantee a refused port.
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatal(err)
+	}
+	port := ln.Addr().(*net.TCPAddr).Port
+	// The test's premise is that nothing listens on this port any more, so a
+	// failed Close must abort the test instead of being silently ignored.
+	if err := ln.Close(); err != nil {
+		t.Fatal(err)
+	}
+
+	// Generous timeouts so the dial fails with the refusal rather than a deadline.
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	dialer := &net.Dialer{Timeout: 2 * time.Second}
+	dw, err := NewDialWrap(dialer, []string{"127.0.0.1"}, "tcp", "127.0.0.1", fmt.Sprintf("%d", port))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	_, dialErr := dw.dialAllParallel(ctx)
+	if dialErr == nil {
+		t.Fatal("expected an error from dialAllParallel to a refused port, got nil")
+	}
+
+	if !errkit.Is(dialErr, ErrPortClosedOrFiltered) {
+		t.Fatalf("connection refused should still be classified as ErrPortClosedOrFiltered, got: %v", dialErr)
+	}
+}
+
+// TestDialContext_DeadlineExceededNotCached verifies that DialContext with a
+// deadline-exceeded first connection does not permanently store the error
+// as ErrPortClosedOrFiltered, so a subsequent caller is not poisoned.
+func TestDialContext_DeadlineExceededNotCached(t *testing.T) {
+	t.Parallel()
+
+	// Start a listener that accepts but never responds (simulates slow host).
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer ln.Close()
+	port := ln.Addr().(*net.TCPAddr).Port
+
+	var accepted atomic.Int32 // counts accepted connections; not asserted on
+	go func() {
+		var held []net.Conn // kept open and silent until the listener is closed
+		for {
+			conn, err := ln.Accept()
+			if err != nil {
+				for _, c := range held {
+					c.Close()
+				}
+				return
+			}
+			accepted.Add(1)
+			held = append(held, conn)
+		}
+	}()
+
+	dw, err := NewDialWrap(
+		&net.Dialer{Timeout: 50 * time.Millisecond},
+		[]string{"127.0.0.1"},
+		"tcp",
+		"127.0.0.1",
+		fmt.Sprintf("%d", port),
+	)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Dial with a context whose deadline has already fired (wait on Done, no guessed sleep).
+	expiredCtx, cancel1 := context.WithTimeout(context.Background(), 1*time.Millisecond)
+	defer cancel1()
+	<-expiredCtx.Done()
+
+	if _, err := dw.DialContext(expiredCtx, "", ""); err == nil {
+		t.Fatal("expected error from DialContext with expired context")
+	}
+
+	// The stored error must NOT be ErrPortClosedOrFiltered.
+	dw.firstConnCond.L.Lock()
+	storedErr := dw.err
+	dw.firstConnCond.L.Unlock()
+
+	if storedErr != nil && errkit.Is(storedErr, ErrPortClosedOrFiltered) {
+		t.Fatalf("deadline-exceeded error must not be cached as ErrPortClosedOrFiltered: %v", storedErr)
+	}
+}
\ No newline at end of file