Skip to content

Commit e29553b

Browse files
authored
Merge pull request #34 from cxz66666/use-both-dns
feat: support UDP/TCP DNS select [WIP]
2 parents 9ed7520 + 772b5bd commit e29553b

File tree

1 file changed

+75
-28
lines changed

1 file changed

+75
-28
lines changed

core/socks.go

Lines changed: 75 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import (
99
"os"
1010
"strconv"
1111
"strings"
12+
"sync"
13+
"time"
1214

1315
"github.com/mythologyli/zju-connect/core/config"
1416

@@ -20,10 +22,31 @@ import (
2022
)
2123

2224
type ZJUDnsResolve struct {
23-
remoteResolver *net.Resolver
25+
remoteUDPResolver *net.Resolver
26+
remoteTCPResolver *net.Resolver
27+
timer *time.Timer
28+
useTCP bool
29+
lock sync.RWMutex
2430
}
2531

26-
func (resolve ZJUDnsResolve) Resolve(ctx context.Context, host string) (context.Context, net.IP, error) {
32+
func (resolve *ZJUDnsResolve) ResolveWithLocal(ctx context.Context, host string) (context.Context, net.IP, error) {
33+
if target, err := net.ResolveIPAddr("ip4", host); err != nil {
34+
log.Printf("Resolve IPv4 addr failed using local DNS: " + host + ". Try IPv6 addr.")
35+
36+
if target, err = net.ResolveIPAddr("ip6", host); err != nil {
37+
log.Printf("Resolve IPv6 addr failed using local DNS: " + host + ". Reject connection.")
38+
return ctx, nil, err
39+
} else {
40+
log.Printf("%s -> %s", host, target.IP.String())
41+
return ctx, target.IP, nil
42+
}
43+
} else {
44+
log.Printf("%s -> %s", host, target.IP.String())
45+
return ctx, target.IP, nil
46+
}
47+
}
48+
49+
func (resolve *ZJUDnsResolve) Resolve(ctx context.Context, host string) (context.Context, net.IP, error) {
2750
if config.IsDnsRuleAvailable() {
2851
if ip, hasDnsRule := config.GetSingleDnsRule(host); hasDnsRule {
2952
ctx = context.WithValue(ctx, "USE_PROXY", true)
@@ -50,43 +73,53 @@ func (resolve ZJUDnsResolve) Resolve(ctx context.Context, host string) (context.
5073
log.Printf("%s -> %s", host, cachedIP.String())
5174
return ctx, cachedIP, nil
5275
} else {
53-
targets, err := resolve.remoteResolver.LookupIP(context.Background(), "ip4", host)
54-
if err != nil {
55-
log.Printf("Resolve IPv4 addr failed using ZJU DNS: " + host + ", using local DNS instead.")
76+
resolve.lock.RLock()
77+
useTCP := resolve.useTCP
78+
resolve.lock.RUnlock()
5679

57-
target, err := net.ResolveIPAddr("ip4", host)
80+
if !useTCP {
81+
targets, err := resolve.remoteUDPResolver.LookupIP(context.Background(), "ip4", host)
5882
if err != nil {
59-
log.Printf("Resolve IPv4 addr failed using local DNS: " + host + ". Try IPv6 addr.")
60-
61-
target, err := net.ResolveIPAddr("ip6", host)
62-
if err != nil {
63-
log.Printf("Resolve IPv6 addr failed using local DNS: " + host + ". Reject connection.")
64-
return ctx, nil, err
83+
if targets, err = resolve.remoteTCPResolver.LookupIP(context.Background(), "ip4", host); err != nil {
84+
// all zju dns failed, so we keep do nothing but use local dns
85+
// host ipv4 and host ipv6 don't set cache
86+
log.Printf("Resolve IPv4 addr failed using ZJU UDP/TCP DNS: " + host + ", using local DNS instead.")
87+
return resolve.ResolveWithLocal(ctx, host)
6588
} else {
66-
log.Printf("%s -> %s", host, target.IP.String())
67-
return ctx, target.IP, nil
89+
resolve.lock.Lock()
90+
resolve.useTCP = true
91+
if resolve.timer == nil {
92+
resolve.timer = time.AfterFunc(10*time.Minute, func() {
93+
resolve.lock.Lock()
94+
resolve.useTCP = false
95+
resolve.timer = nil
96+
resolve.lock.Unlock()
97+
})
98+
}
99+
resolve.lock.Unlock()
68100
}
69-
} else {
70-
log.Printf("%s -> %s", host, target.IP.String())
71-
return ctx, target.IP, nil
72101
}
73-
} else {
102+
// set dns cache if tcp or udp dns success
74103
//TODO: whether we need all dns records? or only 10.0.0.0/8 ?
75104
SetDnsCache(host, targets[0])
76105
log.Printf("%s -> %s", host, targets[0].String())
77106
return ctx, targets[0], nil
107+
} else {
108+
// only try tcp and local dns
109+
if targets, err := resolve.remoteTCPResolver.LookupIP(context.Background(), "ip4", host); err != nil {
110+
log.Printf("Resolve IPv4 addr failed using ZJU TCP DNS: " + host + ", using local DNS instead.")
111+
return resolve.ResolveWithLocal(ctx, host)
112+
} else {
113+
SetDnsCache(host, targets[0])
114+
log.Printf("%s -> %s", host, targets[0].String())
115+
return ctx, targets[0], nil
116+
}
78117
}
79118
}
80119

81120
} else {
82121
// because of OS cache, don't need extra dns memory cache
83-
target, err := net.ResolveIPAddr("ip4", host)
84-
if err != nil {
85-
log.Printf("Resolve IPv4 addr failed using local DNS: " + host + ". Reject connection.")
86-
return ctx, nil, err
87-
} else {
88-
return ctx, target.IP, nil
89-
}
122+
return resolve.ResolveWithLocal(ctx, host)
90123
}
91124
}
92125

@@ -100,7 +133,7 @@ func dialDirect(ctx context.Context, network, addr string) (net.Conn, error) {
100133
}
101134

102135
func ServeSocks5(ipStack *stack.Stack, selfIp []byte, bindAddr string, dnsServer string) {
103-
var remoteResolver = &net.Resolver{
136+
var remoteUDPResolver = &net.Resolver{
104137
PreferGo: true,
105138
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
106139
addrDns := tcpip.FullAddress{
@@ -117,6 +150,17 @@ func ServeSocks5(ipStack *stack.Stack, selfIp []byte, bindAddr string, dnsServer
117150
return gonet.DialUDP(ipStack, &bind, &addrDns, header.IPv4ProtocolNumber)
118151
},
119152
}
153+
var remoteTCPResolver = &net.Resolver{
154+
PreferGo: true,
155+
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
156+
addrDns := tcpip.FullAddress{
157+
NIC: defaultNIC,
158+
Port: uint16(53),
159+
Addr: tcpip.Address(net.ParseIP(dnsServer).To4()),
160+
}
161+
return gonet.DialTCP(ipStack, addrDns, header.IPv4ProtocolNumber)
162+
},
163+
}
120164

121165
var authMethods []socks5.Authenticator
122166
if SocksUser != "" && SocksPasswd != "" {
@@ -244,8 +288,11 @@ func ServeSocks5(ipStack *stack.Stack, selfIp []byte, bindAddr string, dnsServer
244288

245289
server := socks5.NewServer(
246290
socks5.WithAuthMethods(authMethods),
247-
socks5.WithResolver(ZJUDnsResolve{
248-
remoteResolver: remoteResolver,
291+
socks5.WithResolver(&ZJUDnsResolve{
292+
remoteTCPResolver: remoteTCPResolver,
293+
remoteUDPResolver: remoteUDPResolver,
294+
useTCP: false,
295+
timer: nil,
249296
}),
250297
socks5.WithDial(zjuDialer),
251298
socks5.WithLogger(socks5.NewLogger(log.New(os.Stdout, "", log.LstdFlags))),

0 commit comments

Comments
 (0)