diff --git a/backend/api_v1/dns.go b/backend/api_v1/dns.go index caeaaa4..8ef7614 100644 --- a/backend/api_v1/dns.go +++ b/backend/api_v1/dns.go @@ -2,10 +2,13 @@ package api_v1 import ( "encoding/json" + "errors" "fmt" + "net" "reflect" "sort" "strings" + "sync" "github.com/gin-gonic/gin" "github.com/gobeam/stringy" @@ -17,9 +20,6 @@ import ( type dnsParams struct { // Trace is used to define if the DNS record should be traced all the way to the nameserver. Trace bool `form:"trace"` - - // Cache is used to define if the DNS record should only use the DNS cache. - Cache bool `form:"cache"` } // Used to clean the case of things in a value for JSON and remove unwanted keys. @@ -78,252 +78,434 @@ type DNSResponse struct { // Name is used to define the name of the DNS record. Name string `json:"name"` + // DNSServer defines the DNS server which gave this response. + DNSServer string `json:"dnsServer"` + // Value is used to define the value of the DNS record. Value json.RawMessage `json:"value"` + + // Defines the DNS string. + dnsStringify func() string } -type hostnameRecordType struct { - hostname string - recordType uint16 +func godnsLookup(log *zap.Logger, addr string, recordType uint16, hostname string) (*godns.Msg, error) { + // Create the DNS message. + msg := &godns.Msg{} + msg.Id = godns.Id() + msg.RecursionDesired = true + + // DNS servers prefer 1 message per request. Make the question. + msg.Question = []godns.Question{{ + Name: hostname, + Qtype: recordType, + Qclass: godns.StringToClass["IN"], + }} + conn, err := godns.Dial("tcp", addr) + if err != nil { + log.Error("failed to connect to dns server", zap.Error(err)) + return nil, err + } + defer conn.Close() + + // Send the DNS message. + err = conn.WriteMsg(msg) + if err != nil { + return nil, err + } + + // Read the DNS response. + msg, err = conn.ReadMsg() + if err != nil { + log.Error("failed to read from dns server", zap.Error(err)) + } + return msg, err } -func dns(g *gin.RouterGroup, log *zap.Logger, cachedDnsServer string) { - g.GET("/:recordType/:hostname", func(context *gin.Context) { - // Defines if this is JSON. - isJson := context.ContentType() == "application/json" +func findNameserverHostname(log *zap.Logger, addr string, chunks []string) (string, int, uint32, error) { + var msg *godns.Msg + var err error + for i := 0; i < len(chunks); i++ { + // Compile this set of chunks. + hostname := strings.Join(chunks[i:], ".") + "." + + // Do the DNS lookup. + recursionCount := 0 + lookup: + msg, err = godnsLookup(log, addr, godns.StringToType["NS"], hostname) + if err != nil { + continue + } - // Bind the params. - var params dnsParams - if err := context.BindQuery(¶ms); err != nil { - if isJson { - context.JSON(400, map[string]string{ - "message": err.Error(), - }) - } else { - context.String(400, "unable to parse query params: %s", err.Error()) + // Find the answer. + if len(msg.Answer) > 0 { + switch x := msg.Answer[0].(type) { + case *godns.NS: + return x.Ns, i, x.Hdr.Ttl, nil + case *godns.CNAME: + if recursionCount == 50 { + return "", 0, 0, fmt.Errorf("recursion limit reached on %s", hostname) + } + hostname = x.Target + recursionCount++ + goto lookup + default: + return "", 0, 0, errors.New("invalid type for NS record") } - return } + } + if err != nil { + // Errored whilst trying to find NS record. + return "", 0, 0, err + } - // Get the type and hostname from the URL. - recordType := context.Param("recordType") - hostname := context.Param("hostname") - if !strings.HasSuffix(hostname, ".") { - hostname += "." + // Unable to find NS record. + log.Warn("unable to find NS record", zap.String("hostname", strings.Join(chunks, "."))) + return "", 0, 0, nil +} + +func doDnsLookups(log *zap.Logger, dnsServer, recordType string, recursive bool, chunks []string) (map[string][]*DNSResponse, error) { + // Resolve the IP of the DNS server. + var addr string + setServerAddr := func() error { + rawAddr, err := net.ResolveIPAddr("ip", dnsServer) + if err != nil { + return err } + addr = rawAddr.IP.String() + ":53" + return nil + } + if err := setServerAddr(); err != nil { + return nil, err + } - // Make the record type upper case. - recordType = strings.ToUpper(recordType) - originRecordType := recordType + // Keep going through chunks until we get a NS record. + initAddr := addr + oldDnsServer := dnsServer + host, i, ttl, err := findNameserverHostname(log, addr, chunks) + if err != nil { + return nil, err + } + if host == "" { + // Unable to find NS record. + log.Warn("unable to find NS record", zap.String("hostname", strings.Join(chunks, "."))) + } else { + // Turn it into the address. + dnsServer = host + if err = setServerAddr(); err != nil { + return nil, err + } + } + + // If this was a NS lookup that is non-recursive, we have our result here. + if !recursive && recordType == "NS" { + if i == 0 { + // This means that the NS record was on the record specified. + return map[string][]*DNSResponse{ + "NS": { + { + Type: recordType, + TTL: ttl, + Name: strings.TrimRight(dnsServer, "."), + DNSServer: oldDnsServer, + }, + }, + }, nil + } + + // No DNS records were found. + return map[string][]*DNSResponse{"NS": {}}, nil + } - // Defines the record types. - recordTypes := []string{recordType} - recordTypePacket, ok := godns.StringToType[recordType] + // Try to update the DNS server used here. + if err = setServerAddr(); err != nil { + return nil, err + } + + // Define the record types. + recordTypes := []string{recordType} + if recordType == "ANY" { + // Not many DNS resolvers support this anymore, set it to literally all record types. + recordTypes = []string{"A", "AAAA", "CNAME", "MX", "NS", "PTR", "SOA", "TXT"} + } + recordTypesPacket := make([]uint16, len(recordTypes)) + for i, v := range recordTypes { + packetType, ok := godns.StringToType[v] if !ok { - context.String(400, "Invalid record type") - return + return nil, errors.New("invalid record type: " + v) } - recordTypeIdsOnly := []uint16{recordTypePacket} - recordTypesPacket := []hostnameRecordType{{ - hostname: hostname, - recordType: recordTypePacket, - }} - if recordType == "ANY" { - // Since DNS servers rarely support ANY, we need to manually handle it. - recordTypes = []string{"A", "AAAA", "CNAME", "MX", "NS", "PTR", "SOA", "SRV", "TXT"} - recordTypeIdsOnly = make([]uint16, len(recordTypes)) - recordTypesPacket = make([]hostnameRecordType, len(recordTypes)) - for i, v := range recordTypes { - x, _ := godns.StringToType[v] - recordTypeIdsOnly[i] = x - recordTypesPacket[i] = hostnameRecordType{ - hostname: hostname, - recordType: x, - } - } + recordTypesPacket[i] = packetType + } + + // Defines all DNS responses. + responses := map[string][]*DNSResponse{} + responsesLock := sync.Mutex{} + appendToRecordType := func(recordType string, responseArgs ...*DNSResponse) { + responsesLock.Lock() + defer responsesLock.Unlock() + + // Stops an append with zero items being nil. + a := responses[recordType] + if a == nil { + a = []*DNSResponse{} } - // Handle the creation of all relevant DNS messages if trace is on. - if params.Trace { - // Get all relevant dot split. - dots := strings.Split(hostname, ".") - - // Get the current end index. - currentEndIndex := len(dots) - 2 - - // Defines the current fragment. - currentFragment := "" - - // Go through each end index. - for i := currentEndIndex; i > 0; i-- { - currentFragment = dots[i] + "." + currentFragment - for _, v := range recordTypeIdsOnly { - recordTypesPacket = append(recordTypesPacket, hostnameRecordType{ - hostname: currentFragment, - recordType: v, - }) - } + responses[recordType] = append(a, responseArgs...) + } + eg := errgroup.Group{} + for i, recordLoop := range recordTypes { + // Get all items which may not be thread safe. + recordLoopName := recordLoop + packetType := recordTypesPacket[i] + + // Do the DNS lookup. + eg.Go(func() error { + // Do the DNS lookup. + msg, err := godnsLookup(log, addr, packetType, strings.Join(chunks, ".")+".") + if err != nil { + return err } - } - // Go through each record to make the message. - results := make([]*godns.Msg, len(recordTypesPacket)) - anyQclass := godns.StringToClass["IN"] - wg := errgroup.Group{} - for i, v := range recordTypesPacket { - resultPtr := &results[i] - qtypeAndHostname := v - wg.Go(func() error { - // Make the DNS connection. - dnsServer := cachedDnsServer - if !params.Cache { - dnsServer = "1.1.1.1:53" - } - conn, err := godns.Dial("tcp", dnsServer) - if err != nil { - log.Error("failed to connect to dns server", zap.Error(err)) - return err - } + // Make each response. + dnsResponses := make([]*DNSResponse, 0) + answerIteration: + for _, v := range msg.Answer { + // Handle the various responses. + var data json.RawMessage + originalValue := v + resultDnsHost := dnsServer + parseAnswer: + switch x := v.(type) { + case *godns.CNAME: + if recordLoopName == "CNAME" { + // This is to be expected here since we are looking for CNAME records. + b, _ := json.Marshal(x.Target) + data = b + } else { + // In this situation, the DNS configuration is telling us to look elsewhere. + recursionCount := 0 + for recursionCount < 50 { + // Chunkify the CNAME. + chunkifyReady := x.Target + if strings.HasSuffix(chunkifyReady, ".") { + chunkifyReady = strings.TrimRight(chunkifyReady, ".") + } + cnameChunks := strings.Split(chunkifyReady, ".") - // Defer killing the connection to stop leaks. - defer conn.Close() + // Get the NS host. + nsHost, _, _, err := findNameserverHostname(log, initAddr, cnameChunks) + if err != nil { + return err + } + if nsHost == "" { + // Unable to find NS record. + log.Warn("unable to find NS recordLoopName", zap.String("hostname", strings.Join(cnameChunks, "."))) + continue answerIteration + } - // Create the DNS message. - msg := &godns.Msg{} - msg.Id = godns.Id() - msg.RecursionDesired = true + // Turn that into the address. + rawAddr, err := net.ResolveIPAddr("ip", nsHost) + if err != nil { + return err + } + nsAddr := rawAddr.IP.String() + ":53" - // DNS servers prefer 1 message per request. Make the question. - msg.Question = []godns.Question{{ - Name: qtypeAndHostname.hostname, - Qtype: qtypeAndHostname.recordType, - Qclass: anyQclass, - }} + // Lookup the CNAME's value. + cnameLookupMsg, err := godnsLookup(log, nsAddr, packetType, x.Target) + if err != nil { + return err + } - // Send the DNS message. - err = conn.WriteMsg(msg) - if err != nil { - return &gin.Error{ - Err: fmt.Errorf("failed to perform lookup: %v", err), - Type: gin.ErrorTypePublic, + // If there is no answers, continue the root loop. + if len(cnameLookupMsg.Answer) == 0 { + continue answerIteration + } + + // Check if this contains non-CNAME records. + for _, iface := range cnameLookupMsg.Answer { + switch result := iface.(type) { + case *godns.CNAME: + // Ignore this. + default: + // We are past CNAME's! + v = result + resultDnsHost = nsHost + goto parseAnswer + } + } + + // Set the next CNAME we are parsing. + x = cnameLookupMsg.Answer[0].(*godns.CNAME) + + // Add 1 to the recursion count. + recursionCount++ + } + return fmt.Errorf("recordLoopName type %s for host %s has hit recursion limit", recordLoopName, strings.Join(chunks, ".")) + } + default: + // Get the data from the record. + // Due to the nature of the library, this is sadly a little magical. + reflectValue := reflect.Indirect(reflect.ValueOf(v)) + reflectType := reflectValue.Type() + n := reflectType.NumField() + for i := 0; i < n; i++ { + f := reflectType.Field(i) + if strings.ToUpper(f.Name) == recordLoopName { + // This is the field we want. + var err error + data, err = json.Marshal(reflectValue.FieldByName(f.Name).Interface()) + if err != nil { + return fmt.Errorf("failed to marshal json: %v", err) + } + break + } + } + if data == nil { + // In this situation, we will throw it into the JSON cleanifier. + var err error + data, err = json.Marshal(jsonCleanifier{ + Value: v, + RemoveKeys: []string{"Hdr"}, + }) + if err != nil { + return fmt.Errorf("failed to marshal json: %v", err) + } } } - // Read the DNS response. - msg, err = conn.ReadMsg() + // Handle the priority for MX records. + var preference *uint16 + if mx, ok := v.(*godns.MX); ok { + preference = &mx.Preference + } + + // Make the response. + h := originalValue.Header() + r := &DNSResponse{ + Type: recordLoopName, + TTL: h.Ttl, + Name: strings.TrimRight(h.Name, "."), + Value: data, + Preference: preference, + DNSServer: strings.TrimRight(resultDnsHost, "."), + dnsStringify: v.String, + } + dnsResponses = append(dnsResponses, r) + } + appendToRecordType(recordLoopName, dnsResponses...) + return nil + }) + } + + // Handle any additional recursion. + mapChunks := []map[string][]*DNSResponse{} + if recursive { + mapChunks = make([]map[string][]*DNSResponse, len(chunks)-1) + for i = 1; i < len(chunks); i++ { + mapPtr := &mapChunks[i-1] + x := i + eg.Go(func() error { + remainderChunks := chunks[x:] + map_, err := doDnsLookups(log, oldDnsServer, recordType, false, remainderChunks) if err != nil { - log.Error("failed to read from dns server", zap.Error(err)) return err } - - // Set the pointer to the result and return no errors. - *resultPtr = msg + *mapPtr = map_ return nil }) } + } - // Handle any errors. - if err := wg.Wait(); err != nil { - context.Error(err) - return - } + // Go ahead and run the DNS lookups. + if err = eg.Wait(); err != nil { + return nil, err + } - // Sort the types by alphabetical order. - sort.Strings(recordTypes) + // Add all the map keys found in the right order and later. + for _, map_ := range mapChunks { + for k, v := range map_ { + responses[k] = append(responses[k], v...) + } + } - // Handle formatting the results. - strResponses := []string{} - jsonResponses := map[string][]DNSResponse{} - recordTypesLen := len(recordTypes) - for i, response := range results { - // Get the record type. - recordType = recordTypes[i%recordTypesLen] + // Return all responses. + return responses, nil +} - // Continue if record type is not NS/ANY. - if recordType == "NS" && (originRecordType != "ANY" && originRecordType != "NS") { - continue - } +func dns(g *gin.RouterGroup, log *zap.Logger, dnsServer string) { + g.GET("/:recordType/:hostname", func(context *gin.Context) { + // Defines if this is JSON. + isJson := context.ContentType() == "application/json" - // Get the response from the DNS server. - if response.Answer == nil { - // In the case that this is JSON, we don't want to return a nil array. - if isJson { - if _, ok = jsonResponses[recordType]; !ok { - jsonResponses[recordType] = []DNSResponse{} - } - } + // Bind the params. + var params dnsParams + if err := context.BindQuery(¶ms); err != nil { + if isJson { + context.JSON(400, map[string]string{ + "message": err.Error(), + }) } else { - if isJson { - a := make([]DNSResponse, len(response.Answer)) - for i, v := range response.Answer { - // Get the data from the record. - // Due to the nature of the library, this is sadly a little magical. - var data json.RawMessage - reflectValue := reflect.Indirect(reflect.ValueOf(v)) - reflectType := reflectValue.Type() - n := reflectType.NumField() - for i := 0; i < n; i++ { - f := reflectType.Field(i) - if strings.ToUpper(f.Name) == recordType { - // This is the field we want. - var err error - data, err = json.Marshal(reflectValue.FieldByName(f.Name).Interface()) - if err != nil { - context.Error(fmt.Errorf("failed to marshal json: %v", err)) - return - } - break - } - } - if data == nil { - // In this situation, we will throw it into the JSON cleanifier. - var err error - data, err = json.Marshal(jsonCleanifier{ - Value: v, - RemoveKeys: []string{"Hdr"}, - }) - if err != nil { - context.Error(fmt.Errorf("failed to marshal json: %v", err)) - return - } - } - - // Handle the priority for MX records. - var preference *uint16 - if mx, ok := v.(*godns.MX); ok { - preference = &mx.Preference - } + context.String(400, "unable to parse query params: %s", err.Error()) + } + return + } - // Get the response. - h := v.Header() - a[i] = DNSResponse{ - Type: recordType, - TTL: h.Ttl, - Name: h.Name, - Value: data, - Preference: preference, - } - } - if x, ok := jsonResponses[recordType]; ok { - jsonResponses[recordType] = append(x, a...) - } else { - jsonResponses[recordType] = a - } - } else { - // Use the string representation from the DNS library but remove a few chunks. - for _, v := range response.Answer { - s := strings.SplitN(v.String(), "\t", 4) - strResponses = append(strResponses, s[0]+"\t"+s[3]) - } - } + // Get the type and hostname from the URL. + recordType := context.Param("recordType") + hostname := strings.TrimSuffix(context.Param("hostname"), ".") + chunks := []string{} + for _, v := range strings.Split(hostname, ".") { + if v != "" { + chunks = append(chunks, v) } } + if len(chunks) == 0 { + context.Error(&gin.Error{ + Type: gin.ErrorTypePublic, + Err: errors.New("invalid hostname"), + }) + return + } - // Return the response. + // Do the DNS lookup. + results, err := doDnsLookups(log, dnsServer, recordType, params.Trace, chunks) + if err != nil { + context.Error(&gin.Error{ + Type: gin.ErrorTypePublic, + Err: fmt.Errorf("failed to perform dns lookup: %v", err), + }) + return + } + + // Handle JSON responses. if isJson { - context.JSON(200, jsonResponses) - } else { - context.String(200, strings.Join(strResponses, "\n")) + context.JSON(200, results) + return + } + + // Get the keys and order them. + keys := make([]string, len(results)) + i := 0 + for k := range results { + keys[i] = k + i++ + } + sort.Strings(keys) + + // Formulate the text response. + strResponse := "" + for _, key := range keys { + // Get the slice. + s := results[key] + + // Go through each value. + for _, value := range s { + split := strings.SplitN(value.dnsStringify(), "\t", 4) + if strResponse != "" { + strResponse += "\n" + } + strResponse += split[0] + "\t" + split[3] + } } + context.String(200, strResponse) }) } diff --git a/backend/dns/dns_server.go b/backend/dns/dns_server.go index 6562b5a..9c97315 100644 --- a/backend/dns/dns_server.go +++ b/backend/dns/dns_server.go @@ -3,22 +3,13 @@ package dns import ( "go.uber.org/zap" "os" - "regexp" ) -var portRe = regexp.MustCompile(":[0-9]+$") - // GetCachedDNSServer is used to get the cache DNS server. func GetCachedDNSServer(log *zap.Logger) string { // Handle the environment variable override. s := os.Getenv("DNS_SERVER") if s != "" { - // Check if a port is attached. - if !portRe.MatchString(s) { - s += ":53" - } - - // Very poggers. Return here. return s } @@ -36,8 +27,5 @@ func GetCachedDNSServer(log *zap.Logger) string { panic("no DNS server found") } s = ns[len(ns)-1] - if !portRe.MatchString(s) { - s += ":53" - } return s } diff --git a/backend/frontend.go b/backend/frontend.go index 488435a..042bd56 100644 --- a/backend/frontend.go +++ b/backend/frontend.go @@ -186,8 +186,8 @@ func initFrontend(r *gin.Engine, f fs.FS, logger *zap.Logger) { url: /`) } else { errorFrontend(r, logger, nil, "error reading regions.yml") + return } - return } // Attempt to unmarshal the YAML. diff --git a/frontend/.env.development b/frontend/.env.development index 8c4ad56..761ceda 100644 --- a/frontend/.env.development +++ b/frontend/.env.development @@ -1,5 +1,3 @@ -# Currently using Jake's test deployment of the backend for development -# because it is set up to work with Bird -REACT_APP_BACKEND_LONDON_ORIGIN="https://lon-1.tools.k.io" -REACT_APP_BACKEND_US_EAST_ORIGIN="https://lon-1.tools.k.io" -REACT_APP_BACKEND_US_WEST_ORIGIN="https://lon-1.tools.k.io" +REACT_APP_BACKEND_LONDON_ORIGIN="http://localhost:8080" +REACT_APP_BACKEND_US_EAST_ORIGIN="http://localhost:8080" +REACT_APP_BACKEND_US_WEST_ORIGIN="http://localhost:8080" diff --git a/frontend/src/pages/dns/dns-table.tsx b/frontend/src/pages/dns/dns-table.tsx index 218287f..a4d71df 100644 --- a/frontend/src/pages/dns/dns-table.tsx +++ b/frontend/src/pages/dns/dns-table.tsx @@ -1,4 +1,4 @@ -import { FC } from "react"; +import { FC, useMemo } from "react"; import { Button, @@ -9,6 +9,8 @@ import { PopoverTrigger, Table, Tag, + Box, + Text, Tbody, Td, Th, @@ -28,23 +30,40 @@ type DnsTableProps = { record: DnsResponse[DnsType]; }; -const DnsTableHead: FC = ({ record }) => { - const wideValue = !record.find((item) => typeof item.value !== "string"); +const showPriority = (record: DnsResponse[DnsType]) => + !!record.find((item) => typeof item.priority !== "undefined"); +const DnsTableHead: FC = ({ record }) => { return ( - Type + + Type + - Name + + Name + - TTL + + TTL + - {record.find((item) => typeof item.priority !== "undefined") && ( - Priority + {showPriority(record) && ( + + Priority + )} - + Value @@ -52,7 +71,11 @@ const DnsTableHead: FC = ({ record }) => { ); }; -const DnsTableRow: FC<{ row: DnsResponse[DnsType][number] }> = ({ row }) => { +type DnsTableRowProps = { + row: DnsResponse[DnsType][number]; +}; + +const DnsTableRow: FC = ({ row }) => { const arrowColor = useColorModeValue("gray.200", "gray.900"); return ( @@ -74,13 +97,13 @@ const DnsTableRow: FC<{ row: DnsResponse[DnsType][number] }> = ({ row }) => { )} {typeof row.value === "string" && ( - + {row.value} )} {typeof row.value !== "string" && ( - +