From d7f543e58fa4619596a30ec5437b00778f2cb095 Mon Sep 17 00:00:00 2001 From: phuslu Date: Sat, 26 Oct 2024 23:54:12 +0800 Subject: [PATCH] add tls dialer --- client.go | 4 +-- client_dialer.go | 79 ++++++++++++++++++++++++++++++++++++++++++++++++ client_test.go | 44 +++++++++++++++++++-------- 3 files changed, 112 insertions(+), 15 deletions(-) diff --git a/client.go b/client.go index c8e06ca..9693674 100644 --- a/client.go +++ b/client.go @@ -56,13 +56,13 @@ func (c *Client) exchange(ctx context.Context, req, resp *Message) error { _, err = conn.Write(req.Raw) if err != nil { - return nil + return err } resp.Raw = resp.Raw[:cap(resp.Raw)] n, err := conn.Read(resp.Raw) if err != nil { - return nil + return err } resp.Raw = resp.Raw[:n] diff --git a/client_dialer.go b/client_dialer.go index 163fb7a..13bef07 100644 --- a/client_dialer.go +++ b/client_dialer.go @@ -2,6 +2,7 @@ package fastdns import ( "context" + "crypto/tls" "fmt" "io" "net" @@ -66,6 +67,84 @@ func (d *UDPDialer) put(conn net.Conn) { d.conns <- conn } +// TLSDialer is a custom dialer for creating TLS connections. +// It manages a pool of connections to optimize performance in scenarios +// where multiple TLS connections to the same server are required. +type TLSDialer struct { + // Addr specifies the remote TLS address that the dialer will connect to. + Addr *net.TCPAddr + + TLSConfig *tls.Config + + // Timeout specifies the maximum duration for a query to complete. + // If a query exceeds this duration, it will result in a timeout error. + Timeout time.Duration + + // MaxConns limits the maximum number of TLS connections that can be created + // and reused. Once this limit is reached, no new connections will be made. + // If not set, use 64 as default. + MaxConns uint16 + + once sync.Once + conns chan net.Conn +} + +func (d *TLSDialer) DialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) { + return d.get() +} + +func (d *TLSDialer) get() (_ net.Conn, err error) { + d.once.Do(func() { + if d.MaxConns == 0 { + d.MaxConns = 64 + } + d.conns = make(chan net.Conn, d.MaxConns) + for range d.MaxConns { + var c *tls.Conn + c, err = tls.DialWithDialer(&net.Dialer{Timeout: d.Timeout}, "tcp", d.Addr.String(), d.TLSConfig) + if err != nil { + break + } + d.conns <- &tlsConn{c, make([]byte, 0, 1024)} + } + }) + + if err != nil { + return + } + + c := <-d.conns + + return c, nil +} + +func (d *TLSDialer) put(conn net.Conn) { + d.conns <- conn +} + +type tlsConn struct { + *tls.Conn + buffer []byte +} + +func (c *tlsConn) Write(b []byte) (int, error) { + n := len(b) + c.buffer = append(c.buffer[:0], byte(n>>8), byte(n&0xFF)) + c.buffer = append(c.buffer, b...) + _, err := c.Conn.Write(c.buffer) + return n, err +} + +func (c *tlsConn) Read(b []byte) (n int, err error) { + c.buffer = c.buffer[:cap(c.buffer)] + n, err = c.Conn.Read(c.buffer) + if err != nil { + return + } + copy(b, c.buffer[2:n]) + return n - 2, nil +} + // HTTPDialer is a custom dialer for creating HTTP connections. // It allows sending HTTP requests with a specified endpoint, user agent, and transport configuration. type HTTPDialer struct { diff --git a/client_test.go b/client_test.go index 4de82cf..cb14a40 100644 --- a/client_test.go +++ b/client_test.go @@ -87,7 +87,14 @@ func TestClientLookup(t *testing.T) { Addr: "1.1.1.1:53", Dialer: &UDPDialer{ Addr: func() (u *net.UDPAddr) { u, _ = net.ResolveUDPAddr("udp", "1.1.1.1:53"); return }(), - MaxConns: 1000, + MaxConns: 16, + }, + }, + { + Addr: "1.1.1.1:853", + Dialer: &TLSDialer{ + Addr: func() (u *net.TCPAddr) { u, _ = net.ResolveTCPAddr("tcp", "1.1.1.1:853"); return }(), + MaxConns: 16, }, }, { @@ -163,18 +170,6 @@ func BenchmarkResolverPureGo(b *testing.B) { }) } -func BenchmarkResolverCGO(b *testing.B) { - resolver := net.Resolver{PreferGo: false} - - b.ReportAllocs() - b.ResetTimer() - b.RunParallel(func(b *testing.PB) { - for b.Next() { - _, _ = resolver.LookupNetIP(context.Background(), "ip4", "www.google.com") - } - }) -} - func BenchmarkResolverFastdnsDefault(b *testing.B) { server := "8.8.8.8:53" if data, err := os.ReadFile("/etc/resolv.conf"); err == nil { @@ -260,6 +255,29 @@ func BenchmarkResolverFastdnsUDPAppend(b *testing.B) { }) } +// func BenchmarkResolverFastdnsTLS(b *testing.B) { +// server := "1.1.1.1:853" + +// resolver := &Client{ +// Addr: server, +// Dialer: &TLSDialer{ +// Addr: func() (u *net.TCPAddr) { u, _ = net.ResolveTCPAddr("tcp", server); return }(), +// MaxConns: 1024, +// }, +// } + +// b.ReportAllocs() +// b.ResetTimer() +// b.RunParallel(func(pb *testing.PB) { +// for pb.Next() { +// ips, err := resolver.LookupNetIP(context.Background(), "ip4", "www.google.com") +// if len(ips) == 0 || err != nil { +// b.Errorf("fastdns return ips: %+v error: %+v", ips, err) +// } +// } +// }) +// } + func BenchmarkResolverFastdnsHTTP(b *testing.B) { server := "1.1.1.1"