Skip to content

Commit

Permalink
refactor: dns query
Browse files Browse the repository at this point in the history
  • Loading branch information
rkonfj committed May 29, 2023
1 parent bf711d6 commit 8003ac1
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 91 deletions.
77 changes: 17 additions & 60 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
"github.com/miekg/dns"
D "github.com/rkonfj/toh/dns"
"github.com/rkonfj/toh/server/api"
"github.com/rkonfj/toh/spec"
"github.com/sirupsen/logrus"
Expand All @@ -30,6 +31,7 @@ type TohClient struct {
serverIPs []net.IP
serverPort string
dnsClient *dns.Client
resolver *D.Resolver
}

type Options struct {
Expand All @@ -47,23 +49,25 @@ func NewTohClient(options Options) (*TohClient, error) {
dnsClient: &dns.Client{},
connIdleTimeout: 75 * time.Second,
}
dialer := net.Dialer{}
c.resolver = &D.Resolver{Servers: D.DefaultResolver.Servers, Exchange: c.dnsExchange}
c.netDial = func(ctx context.Context, network, addr string) (conn net.Conn, err error) {
if len(c.serverIPs) == 0 {
var host string
host, c.serverPort, err = net.SplitHostPort(addr)
if err != nil {
return
}
c.serverIPs, err = c.directLookupIP(host, dns.TypeA)
c.serverIPs, err = D.LookupIP4(host)
if err == spec.ErrDNSTypeANotFound {
c.serverIPs, err = c.directLookupIP(host, dns.TypeAAAA)
c.serverIPs, err = D.LookupIP6(host)
}
if err != nil {
err = spec.ErrDNSRecordNotFound
return
}
}
return (&net.Dialer{}).DialContext(ctx, network,
return dialer.DialContext(ctx, network,
net.JoinHostPort(c.serverIPs[rand.Intn(len(c.serverIPs))].String(), c.serverPort))
}
c.httpClient = &http.Client{
Expand All @@ -75,7 +79,7 @@ func NewTohClient(options Options) (*TohClient, error) {
}

func (c *TohClient) DNSExchange(dnServer string, query *dns.Msg) (resp *dns.Msg, err error) {
return c.dnsExchange(dnServer, query, false)
return c.dnsExchange(dnServer, query)
}

// LookupIP lookup ipv4 and ipv6
Expand All @@ -91,7 +95,7 @@ func (c *TohClient) LookupIP(host string) (ips []net.IP, err error) {
ip6 = append(ip6, _ips...)
}
}()
_ips, e4 := c.lookupIP(host, dns.TypeA, false)
_ips, e4 := c.LookupIP4(host)
if e4 == nil {
ips = append(ips, _ips...)
}
Expand All @@ -105,12 +109,12 @@ func (c *TohClient) LookupIP(host string) (ips []net.IP, err error) {

// LookupIP4 lookup only ipv4
func (c *TohClient) LookupIP4(host string) (ips []net.IP, err error) {
return c.lookupIP(host, dns.TypeA, false)
return c.resolver.LookupIP(host, dns.TypeA)
}

// LookupIP4 lookup only ipv6
func (c *TohClient) LookupIP6(host string) (ips []net.IP, err error) {
return c.lookupIP(host, dns.TypeAAAA, false)
return c.resolver.LookupIP(host, dns.TypeAAAA)
}

func (c *TohClient) DialTCP(ctx context.Context, addr string) (net.Conn, error) {
Expand Down Expand Up @@ -161,66 +165,19 @@ func (c *TohClient) Stats() (s *api.Stats, err error) {
return
}

func (c *TohClient) dnsExchange(dnServer string, query *dns.Msg, direct bool) (resp *dns.Msg, err error) {
func (c *TohClient) dnsExchange(dnServer string, query *dns.Msg) (resp *dns.Msg, err error) {
dnsLookupCtx, cancel := context.WithTimeout(context.Background(), 25*time.Second)
defer cancel()
if direct {
resp, _, err = c.dnsClient.ExchangeContext(dnsLookupCtx, query, dnServer)
} else {
conn, _err := c.DialUDP(dnsLookupCtx, dnServer)
if _err != nil {
err = fmt.Errorf("dial error: %s", _err.Error())
return
}

defer conn.Close()
resp, _, err = c.dnsClient.ExchangeWithConn(query, &dns.Conn{Conn: &spec.PacketConnWrapper{Conn: conn}})
}
return
}

func (c *TohClient) lookupIP(host string, t uint16, direct bool) (ips []net.IP, err error) {
ip := net.ParseIP(host)
if ip != nil {
ips = append(ips, ip)
return
}
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(host), t)
var resp *dns.Msg
for _, dnServer := range []string{"8.8.8.8:53", "223.5.5.5:53"} {
resp, err = c.dnsExchange(dnServer, query, direct)
if err == nil {
break
}
}
if err != nil {
conn, _err := c.DialUDP(dnsLookupCtx, dnServer)
if _err != nil {
err = fmt.Errorf("dial error: %s", _err.Error())
return
}
for _, a := range resp.Answer {
if a.Header().Rrtype == dns.TypeA {
ips = append(ips, a.(*dns.A).A)
}
if a.Header().Rrtype == dns.TypeAAAA {
ips = append(ips, a.(*dns.AAAA).AAAA)
}
}
if len(ips) == 0 {
if t == dns.TypeA {
err = spec.ErrDNSTypeANotFound
} else if t == dns.TypeAAAA {
err = spec.ErrDNSTypeAAAANotFound
} else {
err = fmt.Errorf("resolve %s : no type %s was found", host, dns.Type(t))
}
}
defer conn.Close()
resp, _, err = c.dnsClient.ExchangeWithConn(query, &dns.Conn{Conn: &spec.PacketConnWrapper{Conn: conn}})
return
}

func (c *TohClient) directLookupIP(host string, t uint16) (ips []net.IP, err error) {
return c.lookupIP(host, t, true)
}

func (c *TohClient) dial(ctx context.Context, network, addr string) (
wsConn *wsConn, remoteAddr net.Addr, err error) {
handshake := http.Header{}
Expand Down
70 changes: 39 additions & 31 deletions cmd/s5/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io"
"math/rand"
"net"
"net/http"
"os"
Expand Down Expand Up @@ -444,30 +445,33 @@ func (s *S5Server) watchSignal() {
}
}

func (s *S5Server) reloadRuleset() {
for _, g := range s.groups {
err := g.ruleset.Reload()
if err != nil {
logrus.Error(err)
}
}
for _, s := range s.servers {
err := s.ruleset.Reload()
if err != nil {
logrus.Error(err)
}
}
ruleset.ResetCache()
s.printRulesetStats()
}

func (s *S5Server) localAddrFamilyDetection() {
if s.opts.Cfg.LocalNet == nil {
return
}
if len(s.opts.Cfg.LocalNet.AddrFamilyDetectURL) == 0 {
return
}
dialer := net.Dialer{}
httpIPv4 := http.Client{
Timeout: 6 * time.Second,
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, "tcp4", addr)
},
},
}
httpIPv6 := http.Client{
Timeout: 6 * time.Second,
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, "tcp6", addr)
},
},
}

httpIPv4 := newHTTPClient(D.LookupIP4)
httpIPv6 := newHTTPClient(D.LookupIP6)

for {
var err error
Expand Down Expand Up @@ -497,21 +501,25 @@ func (s *S5Server) localAddrFamilyDetection() {
}
}

func (s *S5Server) reloadRuleset() {
for _, g := range s.groups {
err := g.ruleset.Reload()
if err != nil {
logrus.Error(err)
}
}
for _, s := range s.servers {
err := s.ruleset.Reload()
if err != nil {
logrus.Error(err)
}
func newHTTPClient(lookupIP func(host string) (ips []net.IP, err error)) *http.Client {
dialer := net.Dialer{}
return &http.Client{
Timeout: 6 * time.Second,
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return
}
ips, err := lookupIP(host)
if err != nil {
return
}
return dialer.DialContext(ctx, network,
net.JoinHostPort(ips[rand.Intn(len(ips))].String(), port))
},
},
}
ruleset.ResetCache()
s.printRulesetStats()
}

func selectServer(servers []*Server) *Server {
Expand Down
74 changes: 74 additions & 0 deletions dns/query.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package dns

import (
"context"
"fmt"
"net"
"time"

"github.com/miekg/dns"
"github.com/rkonfj/toh/spec"
)

var dnsClient *dns.Client = &dns.Client{}

var DefaultResolver Resolver = Resolver{
Servers: []string{"8.8.8.8:53", "223.5.5.5:53"},
Exchange: func(dnServer string, r *dns.Msg) (resp *dns.Msg, err error) {
dnsLookupCtx, cancel := context.WithTimeout(context.Background(), 25*time.Second)
defer cancel()
resp, _, err = dnsClient.ExchangeContext(dnsLookupCtx, r, dnServer)
return
},
}

type Resolver struct {
Servers []string
Exchange func(dnServer string, r *dns.Msg) (*dns.Msg, error)
}

func LookupIP4(host string) (ips []net.IP, err error) {
return DefaultResolver.LookupIP(host, dns.TypeA)
}

func LookupIP6(host string) (ips []net.IP, err error) {
return DefaultResolver.LookupIP(host, dns.TypeAAAA)
}

func (r *Resolver) LookupIP(host string, t uint16) (ips []net.IP, err error) {
ip := net.ParseIP(host)
if ip != nil {
ips = append(ips, ip)
return
}
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(host), t)
var resp *dns.Msg
for _, dnServer := range r.Servers {
resp, err = r.Exchange(dnServer, query)
if err == nil {
break
}
}
if err != nil {
return
}
for _, a := range resp.Answer {
if a.Header().Rrtype == dns.TypeA {
ips = append(ips, a.(*dns.A).A)
}
if a.Header().Rrtype == dns.TypeAAAA {
ips = append(ips, a.(*dns.AAAA).AAAA)
}
}
if len(ips) == 0 {
if t == dns.TypeA {
err = spec.ErrDNSTypeANotFound
} else if t == dns.TypeAAAA {
err = spec.ErrDNSTypeAAAANotFound
} else {
err = fmt.Errorf("resolve %s : no type %s was found", host, dns.Type(t))
}
}
return
}

0 comments on commit 8003ac1

Please sign in to comment.