Skip to content

Commit

Permalink
fix: use non-block dns hijack handler
Browse files Browse the repository at this point in the history
  • Loading branch information
cxz66666 committed Oct 28, 2023
1 parent a27fb13 commit c2dcb1d
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 18 deletions.
2 changes: 2 additions & 0 deletions internal/zcdns/local_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package zcdns
import (
"context"
"github.com/miekg/dns"
"net"
)

type LocalServer interface {
HandleDnsMsg(ctx context.Context, msg *dns.Msg) (*dns.Msg, error)
CheckDnsHijack(dstIP net.IP) bool
}
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func main() {
vpnResolver.SetPermanentDNS(customDns.HostName, ipAddr)
log.Printf("Add custom DNS: %s -> %s\n", customDns.HostName, customDns.IP)
}
localResolver := service.NewDnsServer(vpnResolver)
localResolver := service.NewDnsServer(vpnResolver, []string{conf.ZJUDNSServer, conf.SecondaryDNSServer})
vpnStack.SetupResolve(localResolver)

go vpnStack.Run()
Expand Down
21 changes: 19 additions & 2 deletions service/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ import (
"github.com/miekg/dns"
"github.com/mythologyli/zju-connect/log"
"github.com/mythologyli/zju-connect/resolve"
"net"
)

type DNSServer struct {
resolver *resolve.Resolver
localDNS []net.IP
}

func (d DNSServer) serveDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
Expand All @@ -31,6 +33,15 @@ func (d DNSServer) HandleDnsMsg(ctx context.Context, requestMsg *dns.Msg) (*dns.
return resMsg, err
}

func (d DNSServer) CheckDnsHijack(dstIP net.IP) bool {
for _, ip := range d.localDNS {
if ip.Equal(dstIP) {
return false
}
}
return true
}

func (d DNSServer) handleSingleDNSResolve(ctx context.Context, requestMsg *dns.Msg, resMsg *dns.Msg) error {
switch requestMsg.Opcode {
case dns.OpcodeQuery:
Expand Down Expand Up @@ -65,8 +76,14 @@ func (d DNSServer) handleSingleDNSResolve(ctx context.Context, requestMsg *dns.M
return nil
}

func NewDnsServer(resolver *resolve.Resolver) DNSServer {
return DNSServer{resolver: resolver}
func NewDnsServer(resolver *resolve.Resolver, dnsServers []string) DNSServer {
netIPs := make([]net.IP, len(dnsServers))
for _, dnsServer := range dnsServers {
if net.ParseIP(dnsServer) != nil {
netIPs = append(netIPs, net.ParseIP(dnsServer))
}
}
return DNSServer{resolver: resolver, localDNS: netIPs}
}

func ServeDNS(bindAddr string, dnsServer DNSServer) {
Expand Down
61 changes: 47 additions & 14 deletions stack/tun/stack.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ func (s *Stack) Run() {
continue
}

// whether this should be a blocking operation?
packet := buf[:n]
switch ipVersion := packet[0] >> 4; ipVersion {
case zctcpip.IPv4Version:
Expand Down Expand Up @@ -86,6 +87,8 @@ func (s *Stack) processIPV4(packet zctcpip.IPv4Packet) error {
}

func (s *Stack) processIPV4TCP(packet zctcpip.IPv4Packet, tcpPacket zctcpip.TCPPacket) error {
log.DebugPrintf("receive tcp %s:%d -> %s:%d", packet.SourceIP(), tcpPacket.SourcePort(), packet.DestinationIP(), tcpPacket.DestinationPort())

if !packet.DestinationIP().IsGlobalUnicast() {
return s.endpoint.Write(packet)
}
Expand All @@ -97,19 +100,18 @@ func (s *Stack) processIPV4TCP(packet zctcpip.IPv4Packet, tcpPacket zctcpip.TCPP
}

func (s *Stack) processIPV4UDP(packet zctcpip.IPv4Packet, udpPacket zctcpip.UDPPacket) error {
log.DebugPrintf("receive udp %s:%d -> %s:%d", packet.SourceIP(), udpPacket.SourcePort(), packet.DestinationIP(), udpPacket.DestinationPort())

if !packet.DestinationIP().IsGlobalUnicast() {
return s.endpoint.Write(packet)
}
log.Printf("receive %s:%d -> %s:%d udp", packet.SourceIP(), udpPacket.SourcePort(), packet.DestinationIP(), udpPacket.DestinationPort())

if s.shouldHijackUDPDns(packet, udpPacket) {
log.Printf("hijack %s:%d -> %s:%d dns query", packet.SourceIP(), udpPacket.SourcePort(), packet.DestinationIP(), udpPacket.DestinationPort())
msg := dns.Msg{}
if err := msg.Unpack(udpPacket.Payload()); err != nil {
return err
}
resMsg, err := s.resolve.HandleDnsMsg(context.Background(), &msg)
fmt.Println(resMsg.String(), err)
newPacket := make(zctcpip.IPv4Packet, len(packet))
copy(newPacket, packet)
newUdpPacket := zctcpip.UDPPacket(newPacket.Payload())
// need to be non-blocking
go s.doHijackUDPDns(newPacket, newUdpPacket)
return nil
}

Expand All @@ -121,7 +123,7 @@ func (s *Stack) processIPV4UDP(packet zctcpip.IPv4Packet, udpPacket zctcpip.UDPP
}

func (s *Stack) processIPV4ICMP(packet zctcpip.IPv4Packet, icmpHeader zctcpip.ICMPPacket) error {
log.Printf("icmp %s -> %s", packet.SourceIP(), packet.DestinationIP())
log.DebugPrintf("receive icmp %s -> %s", packet.SourceIP(), packet.DestinationIP())
if icmpHeader.Type() != zctcpip.ICMPTypePingRequest || icmpHeader.Code() != 0 {
return nil
}
Expand All @@ -141,12 +143,43 @@ func (s *Stack) shouldHijackUDPDns(ipHeader zctcpip.IPv4Packet, udpHeader zctcpi
if udpHeader.DestinationPort() != 53 {
return false
}
if ipHeader.SourceIP().Equal(s.endpoint.ip) {
return false
return s.resolve.CheckDnsHijack(ipHeader.DestinationIP())
}

func (s *Stack) doHijackUDPDns(ipHeader zctcpip.IPv4Packet, udpHeader zctcpip.UDPPacket) {
log.Printf("hijack dns %s:%d -> %s:%d", ipHeader.SourceIP(), udpHeader.SourcePort(), ipHeader.DestinationIP(), udpHeader.DestinationPort())
msg := dns.Msg{}
if err := msg.Unpack(udpHeader.Payload()); err != nil {
log.Printf("unpack dns msg error: %v", err)
return
}
if !ipHeader.DestinationIP().IsGlobalUnicast() {
return false
resMsg, err := s.resolve.HandleDnsMsg(context.Background(), &msg)
if err != nil {
log.Printf("hijack dns error: %v", err)
return
}

resByte, err := resMsg.Pack()
if err != nil {
log.Printf("pack dns msg error: %v", err)
return
}

return true
totalLen := int(ipHeader.HeaderLen()) + zctcpip.UDPHeaderSize + len(resByte)

newPacket := make(zctcpip.IPv4Packet, totalLen)
copy(newPacket, ipHeader[:ipHeader.HeaderLen()])
newPacket.SetTotalLength(uint16(totalLen))
newPacket.SetSourceIP(ipHeader.DestinationIP())
newPacket.SetDestinationIP(ipHeader.SourceIP())

newUDPHeader := zctcpip.UDPPacket(newPacket.Payload())
newUDPHeader.SetSourcePort(udpHeader.DestinationPort())
newUDPHeader.SetDestinationPort(udpHeader.SourcePort())
newUDPHeader.SetLength(zctcpip.UDPHeaderSize + uint16(len(resByte)))
copy(newUDPHeader.Payload(), resByte)

newUDPHeader.ResetChecksum(newPacket.PseudoSum())
newPacket.ResetChecksum()
_ = s.endpoint.Write(newPacket)
}
1 change: 1 addition & 0 deletions stack/tun/stack_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"golang.org/x/sys/unix"
"net"
"os/exec"
"strconv"
"syscall"
)

Expand Down
2 changes: 1 addition & 1 deletion stack/tun/stack_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func NewStack(easyConnectClient *client.EasyConnectClient, dnsServer string) (*S
return nil, err
}

dev, err := tun.CreateTUNWithRequestedGUID(interfaceName, &guid, MTU)
dev, err := tun.CreateTUNWithRequestedGUID(interfaceName, &guid, int(MTU))
if err != nil {
return nil, err
}
Expand Down

0 comments on commit c2dcb1d

Please sign in to comment.