From c2dcb1d332ad141583eb55e5e884f4ca0780967d Mon Sep 17 00:00:00 2001 From: ChenXuzheng <1092889706@qq.com> Date: Sat, 28 Oct 2023 19:05:11 +0800 Subject: [PATCH] fix: use non-block dns hijack handler --- internal/zcdns/local_server.go | 2 ++ main.go | 2 +- service/dns.go | 21 ++++++++++-- stack/tun/stack.go | 61 ++++++++++++++++++++++++++-------- stack/tun/stack_darwin.go | 1 + stack/tun/stack_windows.go | 2 +- 6 files changed, 71 insertions(+), 18 deletions(-) diff --git a/internal/zcdns/local_server.go b/internal/zcdns/local_server.go index ca531bf..2dcb380 100644 --- a/internal/zcdns/local_server.go +++ b/internal/zcdns/local_server.go @@ -3,8 +3,10 @@ package zcdns import ( "context" "github.com/miekg/dns" + "net" ) type LocalServer interface { HandleDnsMsg(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) + CheckDnsHijack(dstIP net.IP) bool } diff --git a/main.go b/main.go index dc60e16..c0af858 100644 --- a/main.go +++ b/main.go @@ -111,7 +111,7 @@ func main() { vpnResolver.SetPermanentDNS(customDns.HostName, ipAddr) log.Printf("Add custom DNS: %s -> %s\n", customDns.HostName, customDns.IP) } - localResolver := service.NewDnsServer(vpnResolver) + localResolver := service.NewDnsServer(vpnResolver, []string{conf.ZJUDNSServer, conf.SecondaryDNSServer}) vpnStack.SetupResolve(localResolver) go vpnStack.Run() diff --git a/service/dns.go b/service/dns.go index daf06ee..6810d1d 100644 --- a/service/dns.go +++ b/service/dns.go @@ -6,10 +6,12 @@ import ( "github.com/miekg/dns" "github.com/mythologyli/zju-connect/log" "github.com/mythologyli/zju-connect/resolve" + "net" ) type DNSServer struct { resolver *resolve.Resolver + localDNS []net.IP } func (d DNSServer) serveDNSRequest(w dns.ResponseWriter, r *dns.Msg) { @@ -31,6 +33,15 @@ func (d DNSServer) HandleDnsMsg(ctx context.Context, requestMsg *dns.Msg) (*dns. return resMsg, err } +func (d DNSServer) CheckDnsHijack(dstIP net.IP) bool { + for _, ip := range d.localDNS { + if ip.Equal(dstIP) { + return false + } + } + return true +} + func (d DNSServer) handleSingleDNSResolve(ctx context.Context, requestMsg *dns.Msg, resMsg *dns.Msg) error { switch requestMsg.Opcode { case dns.OpcodeQuery: @@ -65,8 +76,14 @@ func (d DNSServer) handleSingleDNSResolve(ctx context.Context, requestMsg *dns.M return nil } -func NewDnsServer(resolver *resolve.Resolver) DNSServer { - return DNSServer{resolver: resolver} +func NewDnsServer(resolver *resolve.Resolver, dnsServers []string) DNSServer { + netIPs := make([]net.IP, len(dnsServers)) + for _, dnsServer := range dnsServers { + if net.ParseIP(dnsServer) != nil { + netIPs = append(netIPs, net.ParseIP(dnsServer)) + } + } + return DNSServer{resolver: resolver, localDNS: netIPs} } func ServeDNS(bindAddr string, dnsServer DNSServer) { diff --git a/stack/tun/stack.go b/stack/tun/stack.go index 35a0026..45340f8 100644 --- a/stack/tun/stack.go +++ b/stack/tun/stack.go @@ -57,6 +57,7 @@ func (s *Stack) Run() { continue } + // whether this should be a blocking operation? packet := buf[:n] switch ipVersion := packet[0] >> 4; ipVersion { case zctcpip.IPv4Version: @@ -86,6 +87,8 @@ func (s *Stack) processIPV4(packet zctcpip.IPv4Packet) error { } func (s *Stack) processIPV4TCP(packet zctcpip.IPv4Packet, tcpPacket zctcpip.TCPPacket) error { + log.DebugPrintf("receive tcp %s:%d -> %s:%d", packet.SourceIP(), tcpPacket.SourcePort(), packet.DestinationIP(), tcpPacket.DestinationPort()) + if !packet.DestinationIP().IsGlobalUnicast() { return s.endpoint.Write(packet) } @@ -97,19 +100,18 @@ func (s *Stack) processIPV4TCP(packet zctcpip.IPv4Packet, tcpPacket zctcpip.TCPP } func (s *Stack) processIPV4UDP(packet zctcpip.IPv4Packet, udpPacket zctcpip.UDPPacket) error { + log.DebugPrintf("receive udp %s:%d -> %s:%d", packet.SourceIP(), udpPacket.SourcePort(), packet.DestinationIP(), udpPacket.DestinationPort()) + if !packet.DestinationIP().IsGlobalUnicast() { return s.endpoint.Write(packet) } - log.Printf("receive %s:%d -> %s:%d udp", packet.SourceIP(), udpPacket.SourcePort(), packet.DestinationIP(), udpPacket.DestinationPort()) if s.shouldHijackUDPDns(packet, udpPacket) { - log.Printf("hijack %s:%d -> %s:%d dns query", packet.SourceIP(), udpPacket.SourcePort(), packet.DestinationIP(), udpPacket.DestinationPort()) - msg := dns.Msg{} - if err := msg.Unpack(udpPacket.Payload()); err != nil { - return err - } - resMsg, err := s.resolve.HandleDnsMsg(context.Background(), &msg) - fmt.Println(resMsg.String(), err) + newPacket := make(zctcpip.IPv4Packet, len(packet)) + copy(newPacket, packet) + newUdpPacket := zctcpip.UDPPacket(newPacket.Payload()) + // need to be non-blocking + go s.doHijackUDPDns(newPacket, newUdpPacket) return nil } @@ -121,7 +123,7 @@ func (s *Stack) processIPV4UDP(packet zctcpip.IPv4Packet, udpPacket zctcpip.UDPP } func (s *Stack) processIPV4ICMP(packet zctcpip.IPv4Packet, icmpHeader zctcpip.ICMPPacket) error { - log.Printf("icmp %s -> %s", packet.SourceIP(), packet.DestinationIP()) + log.DebugPrintf("receive icmp %s -> %s", packet.SourceIP(), packet.DestinationIP()) if icmpHeader.Type() != zctcpip.ICMPTypePingRequest || icmpHeader.Code() != 0 { return nil } @@ -141,12 +143,43 @@ func (s *Stack) shouldHijackUDPDns(ipHeader zctcpip.IPv4Packet, udpHeader zctcpi if udpHeader.DestinationPort() != 53 { return false } - if ipHeader.SourceIP().Equal(s.endpoint.ip) { - return false + return s.resolve.CheckDnsHijack(ipHeader.DestinationIP()) +} + +func (s *Stack) doHijackUDPDns(ipHeader zctcpip.IPv4Packet, udpHeader zctcpip.UDPPacket) { + log.Printf("hijack dns %s:%d -> %s:%d", ipHeader.SourceIP(), udpHeader.SourcePort(), ipHeader.DestinationIP(), udpHeader.DestinationPort()) + msg := dns.Msg{} + if err := msg.Unpack(udpHeader.Payload()); err != nil { + log.Printf("unpack dns msg error: %v", err) + return } - if !ipHeader.DestinationIP().IsGlobalUnicast() { - return false + resMsg, err := s.resolve.HandleDnsMsg(context.Background(), &msg) + if err != nil { + log.Printf("hijack dns error: %v", err) + return + } + + resByte, err := resMsg.Pack() + if err != nil { + log.Printf("pack dns msg error: %v", err) + return } - return true + totalLen := int(ipHeader.HeaderLen()) + zctcpip.UDPHeaderSize + len(resByte) + + newPacket := make(zctcpip.IPv4Packet, totalLen) + copy(newPacket, ipHeader[:ipHeader.HeaderLen()]) + newPacket.SetTotalLength(uint16(totalLen)) + newPacket.SetSourceIP(ipHeader.DestinationIP()) + newPacket.SetDestinationIP(ipHeader.SourceIP()) + + newUDPHeader := zctcpip.UDPPacket(newPacket.Payload()) + newUDPHeader.SetSourcePort(udpHeader.DestinationPort()) + newUDPHeader.SetDestinationPort(udpHeader.SourcePort()) + newUDPHeader.SetLength(zctcpip.UDPHeaderSize + uint16(len(resByte))) + copy(newUDPHeader.Payload(), resByte) + + newUDPHeader.ResetChecksum(newPacket.PseudoSum()) + newPacket.ResetChecksum() + _ = s.endpoint.Write(newPacket) } diff --git a/stack/tun/stack_darwin.go b/stack/tun/stack_darwin.go index 6c6e3f9..a3161a7 100644 --- a/stack/tun/stack_darwin.go +++ b/stack/tun/stack_darwin.go @@ -7,6 +7,7 @@ import ( "golang.org/x/sys/unix" "net" "os/exec" + "strconv" "syscall" ) diff --git a/stack/tun/stack_windows.go b/stack/tun/stack_windows.go index d86334e..dd02298 100644 --- a/stack/tun/stack_windows.go +++ b/stack/tun/stack_windows.go @@ -73,7 +73,7 @@ func NewStack(easyConnectClient *client.EasyConnectClient, dnsServer string) (*S return nil, err } - dev, err := tun.CreateTUNWithRequestedGUID(interfaceName, &guid, MTU) + dev, err := tun.CreateTUNWithRequestedGUID(interfaceName, &guid, int(MTU)) if err != nil { return nil, err }