Skip to content

Commit

Permalink
add LookupSRV method to client
Browse files Browse the repository at this point in the history
  • Loading branch information
phuslu committed Nov 3, 2024
1 parent ffdbccf commit 0b492bc
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 17 deletions.
94 changes: 94 additions & 0 deletions client_resolver.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package fastdns

import (
"cmp"
"context"
"encoding/binary"
"net"
"net/netip"
"slices"
)

// AppendLookupNetIP looks up host and appends result to dst using the local resolver.
Expand Down Expand Up @@ -242,3 +244,95 @@ func (c *Client) LookupHTTPS(ctx context.Context, host string) (https []NetHTTPS

return
}

// LookupSRV tries to resolve an SRV query of the given service, protocol, and domain name.
// The proto is "tcp" or "udp". The returned records are sorted by priority and randomized by weight within a priority.
func (c *Client) LookupSRV(ctx context.Context, service, proto, name string) (target string, srvs []*net.SRV, err error) {
if service == "" && proto == "" {
target = name
} else {
target = "_" + service + "._" + proto + "." + name
}

req, resp := AcquireMessage(), AcquireMessage()
defer ReleaseMessage(resp)
defer ReleaseMessage(req)

req.SetRequestQuestion(target, TypeSRV, ClassINET)

err = c.Exchange(ctx, req, resp)
if err != nil {
return
}

var buf [256]byte
for r := range resp.Records {
switch r.Type {
case TypeSRV:
if len(r.Data) < 8 {
err = ErrInvalidAnswer
break
}
srvs = append(srvs, &net.SRV{
Target: string(resp.DecodeName(buf[:0], r.Data[6:])),
Port: binary.BigEndian.Uint16(r.Data[4:]),
Priority: binary.BigEndian.Uint16(r.Data[0:]),
Weight: binary.BigEndian.Uint16(r.Data[2:]),
})
default:
err = ErrInvalidAnswer
}
}

if len(srvs) > 1 {
byPriorityWeight(srvs).sort()
}

return
}

// Copy from https://github.com/golang/go/blob/master/src/net/dnsclient.go
// byPriorityWeight sorts SRV records by ascending priority and weight.
type byPriorityWeight []*net.SRV

// shuffleByWeight shuffles SRV records by weight using the algorithm
// described in RFC 2782.
func (addrs byPriorityWeight) shuffleByWeight() {
sum := 0
for _, addr := range addrs {
sum += int(addr.Weight)
}
for sum > 0 && len(addrs) > 1 {
s := 0
n := int(cheaprandn(uint32(sum)))
for i := range addrs {
s += int(addrs[i].Weight)
if s > n {
if i > 0 {
addrs[0], addrs[i] = addrs[i], addrs[0]
}
break
}
}
sum -= int(addrs[0].Weight)
addrs = addrs[1:]
}
}

// sort reorders SRV records as specified in RFC 2782.
func (addrs byPriorityWeight) sort() {
slices.SortFunc(addrs, func(a, b *net.SRV) int {
if r := cmp.Compare(a.Priority, b.Priority); r != 0 {
return r
}
return cmp.Compare(a.Weight, b.Weight)
})
i := 0
for j := 1; j < len(addrs); j++ {
if addrs[i].Priority != addrs[j].Priority {
addrs[i:j].shuffleByWeight()
i = j
}
}
addrs[i:].shuffleByWeight()
}
47 changes: 30 additions & 17 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,23 @@ func TestClientExchange(t *testing.T) {
}
}

func deref(value any) any {
v := reflect.ValueOf(value)
if v.Kind() != reflect.Slice {
return v
}
result := make([]any, v.Len())
for i := 0; i < v.Len(); i++ {
elem := v.Index(i)
if elem.Kind() == reflect.Ptr {
result[i] = elem.Elem().Interface()
} else {
result[i] = elem.Interface()
}
}
return result
}

func TestClientLookup(t *testing.T) {
var cases = []struct {
Host string
Expand Down Expand Up @@ -105,23 +122,6 @@ func TestClientLookup(t *testing.T) {
},
}

deref := func(value any) any {
v := reflect.ValueOf(value)
if v.Kind() != reflect.Slice {
return v
}
result := make([]any, v.Len())
for i := 0; i < v.Len(); i++ {
elem := v.Index(i)
if elem.Kind() == reflect.Ptr {
result[i] = elem.Elem().Interface()
} else {
result[i] = elem.Interface()
}
}
return result
}

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

Expand Down Expand Up @@ -172,6 +172,19 @@ func TestClientLookupNetIP(t *testing.T) {
t.Logf("dns_server=%+v LookupNetIP(%#v) return %+v", client.Addr, host, ips)
}

func TestClientLookupSRV(t *testing.T) {
client := &Client{
Addr: "1.1.1.1:53",
Timeout: 1 * time.Second,
}

cname, ips, err := client.LookupSRV(context.Background(), "xmpp-client", "tcp", "jabber.org")
if err != nil {
t.Errorf("dns_server=%+v LookupSRV(\"_xmpp-client._tcp.jabber.org\") error: %+v\n", client.Addr, err)
}
t.Logf("dns_server=%+v LookupSRV(\"_xmpp-client._tcp.jabber.org\") return %+v %+v", client.Addr, cname, deref(ips))
}

func BenchmarkResolverPureGo(b *testing.B) {
resolver := net.Resolver{PreferGo: true}

Expand Down

0 comments on commit 0b492bc

Please sign in to comment.