From 8003ac1497e1e05c960e2292b130de00b9007a72 Mon Sep 17 00:00:00 2001 From: rkonfj Date: Mon, 29 May 2023 20:37:25 +0800 Subject: [PATCH] refactor: dns query --- client/client.go | 77 +++++++++-------------------------------- cmd/s5/server/server.go | 70 ++++++++++++++++++++----------------- dns/query.go | 74 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 130 insertions(+), 91 deletions(-) create mode 100644 dns/query.go diff --git a/client/client.go b/client/client.go index f0b6a56..5408ba9 100644 --- a/client/client.go +++ b/client/client.go @@ -17,6 +17,7 @@ import ( "github.com/gobwas/ws" "github.com/gobwas/ws/wsutil" "github.com/miekg/dns" + D "github.com/rkonfj/toh/dns" "github.com/rkonfj/toh/server/api" "github.com/rkonfj/toh/spec" "github.com/sirupsen/logrus" @@ -30,6 +31,7 @@ type TohClient struct { serverIPs []net.IP serverPort string dnsClient *dns.Client + resolver *D.Resolver } type Options struct { @@ -47,6 +49,8 @@ func NewTohClient(options Options) (*TohClient, error) { dnsClient: &dns.Client{}, connIdleTimeout: 75 * time.Second, } + dialer := net.Dialer{} + c.resolver = &D.Resolver{Servers: D.DefaultResolver.Servers, Exchange: c.dnsExchange} c.netDial = func(ctx context.Context, network, addr string) (conn net.Conn, err error) { if len(c.serverIPs) == 0 { var host string @@ -54,16 +58,16 @@ func NewTohClient(options Options) (*TohClient, error) { if err != nil { return } - c.serverIPs, err = c.directLookupIP(host, dns.TypeA) + c.serverIPs, err = D.LookupIP4(host) if err == spec.ErrDNSTypeANotFound { - c.serverIPs, err = c.directLookupIP(host, dns.TypeAAAA) + c.serverIPs, err = D.LookupIP6(host) } if err != nil { err = spec.ErrDNSRecordNotFound return } } - return (&net.Dialer{}).DialContext(ctx, network, + return dialer.DialContext(ctx, network, net.JoinHostPort(c.serverIPs[rand.Intn(len(c.serverIPs))].String(), c.serverPort)) } c.httpClient = &http.Client{ @@ -75,7 +79,7 @@ func NewTohClient(options Options) (*TohClient, error) { } func (c *TohClient) DNSExchange(dnServer string, query *dns.Msg) (resp *dns.Msg, err error) { - return c.dnsExchange(dnServer, query, false) + return c.dnsExchange(dnServer, query) } // LookupIP lookup ipv4 and ipv6 @@ -91,7 +95,7 @@ func (c *TohClient) LookupIP(host string) (ips []net.IP, err error) { ip6 = append(ip6, _ips...) } }() - _ips, e4 := c.lookupIP(host, dns.TypeA, false) + _ips, e4 := c.LookupIP4(host) if e4 == nil { ips = append(ips, _ips...) } @@ -105,12 +109,12 @@ func (c *TohClient) LookupIP(host string) (ips []net.IP, err error) { // LookupIP4 lookup only ipv4 func (c *TohClient) LookupIP4(host string) (ips []net.IP, err error) { - return c.lookupIP(host, dns.TypeA, false) + return c.resolver.LookupIP(host, dns.TypeA) } // LookupIP4 lookup only ipv6 func (c *TohClient) LookupIP6(host string) (ips []net.IP, err error) { - return c.lookupIP(host, dns.TypeAAAA, false) + return c.resolver.LookupIP(host, dns.TypeAAAA) } func (c *TohClient) DialTCP(ctx context.Context, addr string) (net.Conn, error) { @@ -161,66 +165,19 @@ func (c *TohClient) Stats() (s *api.Stats, err error) { return } -func (c *TohClient) dnsExchange(dnServer string, query *dns.Msg, direct bool) (resp *dns.Msg, err error) { +func (c *TohClient) dnsExchange(dnServer string, query *dns.Msg) (resp *dns.Msg, err error) { dnsLookupCtx, cancel := context.WithTimeout(context.Background(), 25*time.Second) defer cancel() - if direct { - resp, _, err = c.dnsClient.ExchangeContext(dnsLookupCtx, query, dnServer) - } else { - conn, _err := c.DialUDP(dnsLookupCtx, dnServer) - if _err != nil { - err = fmt.Errorf("dial error: %s", _err.Error()) - return - } - - defer conn.Close() - resp, _, err = c.dnsClient.ExchangeWithConn(query, &dns.Conn{Conn: &spec.PacketConnWrapper{Conn: conn}}) - } - return -} - -func (c *TohClient) lookupIP(host string, t uint16, direct bool) (ips []net.IP, err error) { - ip := net.ParseIP(host) - if ip != nil { - ips = append(ips, ip) - return - } - query := &dns.Msg{} - query.SetQuestion(dns.Fqdn(host), t) - var resp *dns.Msg - for _, dnServer := range []string{"8.8.8.8:53", "223.5.5.5:53"} { - resp, err = c.dnsExchange(dnServer, query, direct) - if err == nil { - break - } - } - if err != nil { + conn, _err := c.DialUDP(dnsLookupCtx, dnServer) + if _err != nil { + err = fmt.Errorf("dial error: %s", _err.Error()) return } - for _, a := range resp.Answer { - if a.Header().Rrtype == dns.TypeA { - ips = append(ips, a.(*dns.A).A) - } - if a.Header().Rrtype == dns.TypeAAAA { - ips = append(ips, a.(*dns.AAAA).AAAA) - } - } - if len(ips) == 0 { - if t == dns.TypeA { - err = spec.ErrDNSTypeANotFound - } else if t == dns.TypeAAAA { - err = spec.ErrDNSTypeAAAANotFound - } else { - err = fmt.Errorf("resolve %s : no type %s was found", host, dns.Type(t)) - } - } + defer conn.Close() + resp, _, err = c.dnsClient.ExchangeWithConn(query, &dns.Conn{Conn: &spec.PacketConnWrapper{Conn: conn}}) return } -func (c *TohClient) directLookupIP(host string, t uint16) (ips []net.IP, err error) { - return c.lookupIP(host, t, true) -} - func (c *TohClient) dial(ctx context.Context, network, addr string) ( wsConn *wsConn, remoteAddr net.Addr, err error) { handshake := http.Header{} diff --git a/cmd/s5/server/server.go b/cmd/s5/server/server.go index 1efe71e..81c23db 100644 --- a/cmd/s5/server/server.go +++ b/cmd/s5/server/server.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "math/rand" "net" "net/http" "os" @@ -444,6 +445,23 @@ func (s *S5Server) watchSignal() { } } +func (s *S5Server) reloadRuleset() { + for _, g := range s.groups { + err := g.ruleset.Reload() + if err != nil { + logrus.Error(err) + } + } + for _, s := range s.servers { + err := s.ruleset.Reload() + if err != nil { + logrus.Error(err) + } + } + ruleset.ResetCache() + s.printRulesetStats() +} + func (s *S5Server) localAddrFamilyDetection() { if s.opts.Cfg.LocalNet == nil { return @@ -451,23 +469,9 @@ func (s *S5Server) localAddrFamilyDetection() { if len(s.opts.Cfg.LocalNet.AddrFamilyDetectURL) == 0 { return } - dialer := net.Dialer{} - httpIPv4 := http.Client{ - Timeout: 6 * time.Second, - Transport: &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return dialer.DialContext(ctx, "tcp4", addr) - }, - }, - } - httpIPv6 := http.Client{ - Timeout: 6 * time.Second, - Transport: &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return dialer.DialContext(ctx, "tcp6", addr) - }, - }, - } + + httpIPv4 := newHTTPClient(D.LookupIP4) + httpIPv6 := newHTTPClient(D.LookupIP6) for { var err error @@ -497,21 +501,25 @@ func (s *S5Server) localAddrFamilyDetection() { } } -func (s *S5Server) reloadRuleset() { - for _, g := range s.groups { - err := g.ruleset.Reload() - if err != nil { - logrus.Error(err) - } - } - for _, s := range s.servers { - err := s.ruleset.Reload() - if err != nil { - logrus.Error(err) - } +func newHTTPClient(lookupIP func(host string) (ips []net.IP, err error)) *http.Client { + dialer := net.Dialer{} + return &http.Client{ + Timeout: 6 * time.Second, + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return + } + ips, err := lookupIP(host) + if err != nil { + return + } + return dialer.DialContext(ctx, network, + net.JoinHostPort(ips[rand.Intn(len(ips))].String(), port)) + }, + }, } - ruleset.ResetCache() - s.printRulesetStats() } func selectServer(servers []*Server) *Server { diff --git a/dns/query.go b/dns/query.go new file mode 100644 index 0000000..a1bcdc8 --- /dev/null +++ b/dns/query.go @@ -0,0 +1,74 @@ +package dns + +import ( + "context" + "fmt" + "net" + "time" + + "github.com/miekg/dns" + "github.com/rkonfj/toh/spec" +) + +var dnsClient *dns.Client = &dns.Client{} + +var DefaultResolver Resolver = Resolver{ + Servers: []string{"8.8.8.8:53", "223.5.5.5:53"}, + Exchange: func(dnServer string, r *dns.Msg) (resp *dns.Msg, err error) { + dnsLookupCtx, cancel := context.WithTimeout(context.Background(), 25*time.Second) + defer cancel() + resp, _, err = dnsClient.ExchangeContext(dnsLookupCtx, r, dnServer) + return + }, +} + +type Resolver struct { + Servers []string + Exchange func(dnServer string, r *dns.Msg) (*dns.Msg, error) +} + +func LookupIP4(host string) (ips []net.IP, err error) { + return DefaultResolver.LookupIP(host, dns.TypeA) +} + +func LookupIP6(host string) (ips []net.IP, err error) { + return DefaultResolver.LookupIP(host, dns.TypeAAAA) +} + +func (r *Resolver) LookupIP(host string, t uint16) (ips []net.IP, err error) { + ip := net.ParseIP(host) + if ip != nil { + ips = append(ips, ip) + return + } + query := &dns.Msg{} + query.SetQuestion(dns.Fqdn(host), t) + var resp *dns.Msg + for _, dnServer := range r.Servers { + resp, err = r.Exchange(dnServer, query) + if err == nil { + break + } + } + if err != nil { + return + } + for _, a := range resp.Answer { + if a.Header().Rrtype == dns.TypeA { + ips = append(ips, a.(*dns.A).A) + } + if a.Header().Rrtype == dns.TypeAAAA { + ips = append(ips, a.(*dns.AAAA).AAAA) + } + } + if len(ips) == 0 { + if t == dns.TypeA { + err = spec.ErrDNSTypeANotFound + } else if t == dns.TypeAAAA { + err = spec.ErrDNSTypeAAAANotFound + } else { + err = fmt.Errorf("resolve %s : no type %s was found", host, dns.Type(t)) + } + } + return +}