Skip to content

Commit

Permalink
pkg client: dialWS use custom dns query
Browse files Browse the repository at this point in the history
  • Loading branch information
rkonfj committed May 26, 2023
1 parent cf263df commit 9af4341
Showing 1 changed file with 30 additions and 27 deletions.
57 changes: 30 additions & 27 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
type TohClient struct {
options Options
connIdleTimeout time.Duration
netDial func(ctx context.Context, network, addr string) (conn net.Conn, err error)
httpClient *http.Client
serverIPs []net.IP
serverPort string
Expand All @@ -46,27 +47,28 @@ func NewTohClient(options Options) (*TohClient, error) {
dnsClient: &dns.Client{},
connIdleTimeout: 75 * time.Second,
}
c.netDial = func(ctx context.Context, network, addr string) (conn net.Conn, err error) {
if len(c.serverIPs) == 0 {
var host string
host, c.serverPort, err = net.SplitHostPort(addr)
if err != nil {
return
}
c.serverIPs, err = c.directLookupIP(host, dns.TypeA)
if err == spec.ErrDNSTypeANotFound {
c.serverIPs, err = c.directLookupIP(host, dns.TypeAAAA)
}
if err != nil {
err = spec.ErrDNSRecordNotFound
return
}
}
return (&net.Dialer{}).DialContext(ctx, network,
net.JoinHostPort(c.serverIPs[rand.Intn(len(c.serverIPs))].String(), c.serverPort))
}
c.httpClient = &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (conn net.Conn, err error) {
if len(c.serverIPs) == 0 {
var host string
host, c.serverPort, err = net.SplitHostPort(addr)
if err != nil {
return
}
c.serverIPs, err = c.directLookupIP(host, dns.TypeA)
if err == spec.ErrDNSTypeANotFound {
c.serverIPs, err = c.directLookupIP(host, dns.TypeAAAA)
}
if err != nil {
err = spec.ErrDNSRecordNotFound
return
}
}
return (&net.Dialer{}).DialContext(ctx, network,
net.JoinHostPort(c.serverIPs[rand.Intn(len(c.serverIPs))].String(), c.serverPort))
},
DialContext: c.netDial,
},
}
return c, nil
Expand Down Expand Up @@ -233,7 +235,7 @@ func (c *TohClient) dial(ctx context.Context, network, addr string) (

t1 := time.Now()

wsConn, respHeader, err := dialWS(ctx, c.options.Server, handshake)
wsConn, respHeader, err := c.dialWS(ctx, c.options.Server, handshake)
if err != nil {
return
}
Expand Down Expand Up @@ -281,17 +283,13 @@ func (c *TohClient) newPingLoop(wsConn *wsConn) {
}
}

type wsConn struct {
conn net.Conn
lastActiveTime time.Time
}

func dialWS(ctx context.Context, urlstr string, header http.Header) (
func (c *TohClient) dialWS(ctx context.Context, urlstr string, header http.Header) (
wsc *wsConn, respHeader http.Header, err error) {
respHeader = http.Header{}
var statusCode int
dialer := ws.Dialer{
Header: ws.HandshakeHeaderHTTP(header),
NetDial: c.netDial,
Header: ws.HandshakeHeaderHTTP(header),
OnHeader: func(key, value []byte) (err error) {
respHeader.Add(string(key), string(value))
return
Expand Down Expand Up @@ -329,6 +327,11 @@ func dialWS(ctx context.Context, urlstr string, header http.Header) (
return
}

type wsConn struct {
conn net.Conn
lastActiveTime time.Time
}

func (c *wsConn) Read(ctx context.Context) (b []byte, err error) {
c.lastActiveTime = time.Now()
if dl, ok := ctx.Deadline(); ok {
Expand Down

0 comments on commit 9af4341

Please sign in to comment.