diff --git a/.gitignore b/.gitignore index f7f8ac3..f1298ea 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /.idea/ /vendor/ +.DS_Store diff --git a/client.go b/client.go index 23b5f58..79a067a 100644 --- a/client.go +++ b/client.go @@ -89,11 +89,11 @@ func (c *Client) Start() { } } -func (c *Client) Exchange(ctx context.Context, transport Transport, message *dns.Msg, strategy DomainStrategy) (*dns.Msg, error) { - return c.ExchangeWithResponseCheck(ctx, transport, message, strategy, nil) +func (c *Client) Exchange(ctx context.Context, transport Transport, message *dns.Msg, options QueryOptions) (*dns.Msg, error) { + return c.ExchangeWithResponseCheck(ctx, transport, message, options, nil) } -func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transport, message *dns.Msg, strategy DomainStrategy, responseChecker func(response *dns.Msg) bool) (*dns.Msg, error) { +func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transport, message *dns.Msg, options QueryOptions, responseChecker func(response *dns.Msg) bool) (*dns.Msg, error) { if len(message.Question) == 0 { if c.logger != nil { c.logger.WarnContext(ctx, "bad question size: ", len(message.Question)) @@ -109,15 +109,14 @@ func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transp return &responseMessage, nil } question := message.Question[0] - clientSubnet, clientSubnetLoaded := ClientSubnetFromContext(ctx) - if clientSubnetLoaded { - message = SetClientSubnet(message, clientSubnet, true) + if options.ClientSubnet.IsValid() { + message = SetClientSubnet(message, options.ClientSubnet, true) } isSimpleRequest := len(message.Question) == 1 && len(message.Ns) == 0 && len(message.Extra) == 0 && - !clientSubnetLoaded - disableCache := !isSimpleRequest || c.disableCache || DisableCacheFromContext(ctx) + !options.ClientSubnet.IsValid() + disableCache := !isSimpleRequest || c.disableCache || options.DisableCache if !disableCache { response, ttl := c.loadResponse(question, transport) if response != nil { @@ -126,7 +125,7 @@ func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transp return response, nil } } - if question.Qtype == dns.TypeA && strategy == DomainStrategyUseIPv6 || question.Qtype == dns.TypeAAAA && strategy == DomainStrategyUseIPv4 { + if question.Qtype == dns.TypeA && options.Strategy == DomainStrategyUseIPv6 || question.Qtype == dns.TypeAAAA && options.Strategy == DomainStrategyUseIPv4 { responseMessage := dns.Msg{ MsgHdr: dns.MsgHdr{ Id: message.Id, @@ -142,7 +141,7 @@ func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transp } if !transport.Raw() { if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA { - return c.exchangeToLookup(ctx, transport, message, question) + return c.exchangeToLookup(ctx, transport, message, question, options) } return nil, ErrNoRawSupport } @@ -171,7 +170,7 @@ func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transp return response, ErrResponseRejected } if question.Qtype == dns.TypeHTTPS { - if strategy == DomainStrategyUseIPv4 || strategy == DomainStrategyUseIPv6 { + if options.Strategy == DomainStrategyUseIPv4 || options.Strategy == DomainStrategyUseIPv6 { for _, rr := range response.Answer { https, isHTTPS := rr.(*dns.HTTPS) if !isHTTPS { @@ -179,7 +178,7 @@ func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transp } content := https.SVCB content.Value = common.Filter(content.Value, func(it dns.SVCBKeyValue) bool { - if strategy == DomainStrategyUseIPv4 { + if options.Strategy == DomainStrategyUseIPv4 { return it.Key() != dns.SVCB_IPV6HINT } else { return it.Key() != dns.SVCB_IPV4HINT @@ -197,8 +196,8 @@ func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transp } } } - if rewriteTTL, loaded := RewriteTTLFromContext(ctx); loaded { - timeToLive = int(rewriteTTL) + if options.RewriteTTL != nil { + timeToLive = int(*options.RewriteTTL) } for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} { for _, record := range recordList { @@ -213,26 +212,26 @@ func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transp return response, err } -func (c *Client) Lookup(ctx context.Context, transport Transport, domain string, strategy DomainStrategy) ([]netip.Addr, error) { - return c.LookupWithResponseCheck(ctx, transport, domain, strategy, nil) +func (c *Client) Lookup(ctx context.Context, transport Transport, domain string, options QueryOptions) ([]netip.Addr, error) { + return c.LookupWithResponseCheck(ctx, transport, domain, options, nil) } -func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transport, domain string, strategy DomainStrategy, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) { +func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transport, domain string, options QueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) { if dns.IsFqdn(domain) { domain = domain[:len(domain)-1] } dnsName := dns.Fqdn(domain) if transport.Raw() { - if strategy == DomainStrategyUseIPv4 { - return c.lookupToExchange(ctx, transport, dnsName, dns.TypeA, strategy, responseChecker) - } else if strategy == DomainStrategyUseIPv6 { - return c.lookupToExchange(ctx, transport, dnsName, dns.TypeAAAA, strategy, responseChecker) + if options.Strategy == DomainStrategyUseIPv4 { + return c.lookupToExchange(ctx, transport, dnsName, dns.TypeA, options, responseChecker) + } else if options.Strategy == DomainStrategyUseIPv6 { + return c.lookupToExchange(ctx, transport, dnsName, dns.TypeAAAA, options, responseChecker) } var response4 []netip.Addr var response6 []netip.Addr var group task.Group group.Append("exchange4", func(ctx context.Context) error { - response, err := c.lookupToExchange(ctx, transport, dnsName, dns.TypeA, strategy, responseChecker) + response, err := c.lookupToExchange(ctx, transport, dnsName, dns.TypeA, options, responseChecker) if err != nil { return err } @@ -240,7 +239,7 @@ func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transpor return nil }) group.Append("exchange6", func(ctx context.Context) error { - response, err := c.lookupToExchange(ctx, transport, dnsName, dns.TypeAAAA, strategy, responseChecker) + response, err := c.lookupToExchange(ctx, transport, dnsName, dns.TypeAAAA, options, responseChecker) if err != nil { return err } @@ -251,11 +250,11 @@ func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transpor if len(response4) == 0 && len(response6) == 0 { return nil, err } - return sortAddresses(response4, response6, strategy), nil + return sortAddresses(response4, response6, options.Strategy), nil } - disableCache := c.disableCache || DisableCacheFromContext(ctx) + disableCache := c.disableCache || options.DisableCache if !disableCache { - if strategy == DomainStrategyUseIPv4 { + if options.Strategy == DomainStrategyUseIPv4 { response, err := c.questionCache(dns.Question{ Name: dnsName, Qtype: dns.TypeA, @@ -264,7 +263,7 @@ func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transpor if err != ErrNotCached { return response, err } - } else if strategy == DomainStrategyUseIPv6 { + } else if options.Strategy == DomainStrategyUseIPv6 { response, err := c.questionCache(dns.Question{ Name: dnsName, Qtype: dns.TypeAAAA, @@ -285,16 +284,16 @@ func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transpor Qclass: dns.ClassINET, }, transport) if len(response4) > 0 || len(response6) > 0 { - return sortAddresses(response4, response6, strategy), nil + return sortAddresses(response4, response6, options.Strategy), nil } } } if responseChecker != nil && c.rdrc != nil { var rejected bool - if strategy != DomainStrategyUseIPv6 { + if options.Strategy != DomainStrategyUseIPv6 { rejected = c.rdrc.LoadRDRC(transport.Name(), dnsName, dns.TypeA) } - if !rejected && strategy != DomainStrategyUseIPv4 { + if !rejected && options.Strategy != DomainStrategyUseIPv4 { rejected = c.rdrc.LoadRDRC(transport.Name(), dnsName, dns.TypeAAAA) } if rejected { @@ -303,7 +302,7 @@ func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transpor } ctx, cancel := context.WithTimeout(ctx, c.timeout) var rCode int - response, err := transport.Lookup(ctx, domain, strategy) + response, err := transport.Lookup(ctx, domain, options.Strategy) cancel() if err != nil { return nil, wrapError(err) @@ -329,12 +328,12 @@ func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transpor } if !disableCache { var timeToLive uint32 - if rewriteTTL, loaded := RewriteTTLFromContext(ctx); loaded { - timeToLive = rewriteTTL + if options.RewriteTTL != nil { + timeToLive = *options.RewriteTTL } else { timeToLive = DefaultTTL } - if strategy != DomainStrategyUseIPv6 { + if options.Strategy != DomainStrategyUseIPv6 { question4 := dns.Question{ Name: dnsName, Qtype: dns.TypeA, @@ -362,7 +361,7 @@ func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transpor } c.storeCache(transport, question4, message4, int(timeToLive)) } - if strategy != DomainStrategyUseIPv4 { + if options.Strategy != DomainStrategyUseIPv4 { question6 := dns.Question{ Name: dnsName, Qtype: dns.TypeAAAA, @@ -404,11 +403,7 @@ func (c *Client) ClearCache() { } func (c *Client) LookupCache(ctx context.Context, domain string, strategy DomainStrategy) ([]netip.Addr, bool) { - if c.independentCache { - return nil, false - } - disableCache := c.disableCache || DisableCacheFromContext(ctx) - if disableCache { + if c.disableCache || c.independentCache { return nil, false } if dns.IsFqdn(domain) { @@ -452,19 +447,10 @@ func (c *Client) LookupCache(ctx context.Context, domain string, strategy Domain } func (c *Client) ExchangeCache(ctx context.Context, message *dns.Msg) (*dns.Msg, bool) { - if c.independentCache || len(message.Question) != 1 { + if c.disableCache || c.independentCache || len(message.Question) != 1 { return nil, false } question := message.Question[0] - _, clientSubnetLoaded := transportNameFromContext(ctx) - isSimpleRequest := len(message.Question) == 1 && - len(message.Ns) == 0 && - len(message.Extra) == 0 && - !clientSubnetLoaded - disableCache := !isSimpleRequest || c.disableCache || DisableCacheFromContext(ctx) - if disableCache { - return nil, false - } response, ttl := c.loadResponse(question, nil) if response == nil { return nil, false @@ -508,15 +494,14 @@ func (c *Client) storeCache(transport Transport, question dns.Question, message } } -func (c *Client) exchangeToLookup(ctx context.Context, transport Transport, message *dns.Msg, question dns.Question) (*dns.Msg, error) { +func (c *Client) exchangeToLookup(ctx context.Context, transport Transport, message *dns.Msg, question dns.Question, options QueryOptions) (*dns.Msg, error) { domain := question.Name - var strategy DomainStrategy if question.Qtype == dns.TypeA { - strategy = DomainStrategyUseIPv4 + options.Strategy = DomainStrategyUseIPv4 } else { - strategy = DomainStrategyUseIPv6 + options.Strategy = DomainStrategyUseIPv6 } - result, err := c.Lookup(ctx, transport, domain, strategy) + result, err := c.Lookup(ctx, transport, domain, options) if err != nil { return nil, wrapError(err) } @@ -529,8 +514,8 @@ func (c *Client) exchangeToLookup(ctx context.Context, transport Transport, mess Question: message.Question, } var timeToLive uint32 - if rewriteTTL, loaded := RewriteTTLFromContext(ctx); loaded { - timeToLive = rewriteTTL + if options.RewriteTTL != nil { + timeToLive = *options.RewriteTTL } else { timeToLive = DefaultTTL } @@ -563,13 +548,13 @@ func (c *Client) exchangeToLookup(ctx context.Context, transport Transport, mess return &response, nil } -func (c *Client) lookupToExchange(ctx context.Context, transport Transport, name string, qType uint16, strategy DomainStrategy, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) { +func (c *Client) lookupToExchange(ctx context.Context, transport Transport, name string, qType uint16, options QueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) { question := dns.Question{ Name: name, Qtype: qType, Qclass: dns.ClassINET, } - disableCache := c.disableCache || DisableCacheFromContext(ctx) + disableCache := c.disableCache || options.DisableCache if !disableCache { cachedAddresses, err := c.questionCache(question, transport) if err != ErrNotCached { @@ -587,7 +572,7 @@ func (c *Client) lookupToExchange(ctx context.Context, transport Transport, name err error ) if responseChecker != nil { - response, err = c.ExchangeWithResponseCheck(ctx, transport, &message, strategy, func(response *dns.Msg) bool { + response, err = c.ExchangeWithResponseCheck(ctx, transport, &message, options, func(response *dns.Msg) bool { addresses, addrErr := MessageToAddresses(response) if addrErr != nil { return false @@ -595,7 +580,7 @@ func (c *Client) lookupToExchange(ctx context.Context, transport Transport, name return responseChecker(addresses) }) } else { - response, err = c.Exchange(ctx, transport, &message, strategy) + response, err = c.Exchange(ctx, transport, &message, options) } if err != nil { return nil, err @@ -718,3 +703,14 @@ func wrapError(err error) error { } return err } + +type transportKey struct{} + +func contextWithTransportName(ctx context.Context, transportName string) context.Context { + return context.WithValue(ctx, transportKey{}, transportName) +} + +func transportNameFromContext(ctx context.Context) (string, bool) { + value, loaded := ctx.Value(transportKey{}).(string) + return value, loaded +} diff --git a/client_options.go b/client_options.go new file mode 100644 index 0000000..1eccb26 --- /dev/null +++ b/client_options.go @@ -0,0 +1,10 @@ +package dns + +import "net/netip" + +type QueryOptions struct { + Strategy DomainStrategy + DisableCache bool + RewriteTTL *uint32 + ClientSubnet netip.Prefix +} diff --git a/dialer.go b/dialer.go index 23a89ab..65ac5fe 100644 --- a/dialer.go +++ b/dialer.go @@ -25,7 +25,9 @@ func (d *DialerWrapper) DialContext(ctx context.Context, network string, destina if destination.IsIP() { return d.dialer.DialContext(ctx, network, destination) } - addresses, err := d.client.Lookup(ctx, d.transport, destination.Fqdn, d.strategy) + addresses, err := d.client.Lookup(ctx, d.transport, destination.Fqdn, QueryOptions{ + Strategy: d.strategy, + }) if err != nil { return nil, err } @@ -36,7 +38,9 @@ func (d *DialerWrapper) ListenPacket(ctx context.Context, destination M.Socksadd if destination.IsIP() { return d.dialer.ListenPacket(ctx, destination) } - addresses, err := d.client.Lookup(ctx, d.transport, destination.Fqdn, d.strategy) + addresses, err := d.client.Lookup(ctx, d.transport, destination.Fqdn, QueryOptions{ + Strategy: d.strategy, + }) if err != nil { return nil, err } diff --git a/extensions.go b/extensions.go deleted file mode 100644 index 86bbbd5..0000000 --- a/extensions.go +++ /dev/null @@ -1,56 +0,0 @@ -package dns - -import ( - "context" - "net/netip" -) - -type disableCacheKey struct{} - -func ContextWithDisableCache(ctx context.Context, val bool) context.Context { - return context.WithValue(ctx, (*disableCacheKey)(nil), val) -} - -func DisableCacheFromContext(ctx context.Context) bool { - val := ctx.Value((*disableCacheKey)(nil)) - if val == nil { - return false - } - return val.(bool) -} - -type rewriteTTLKey struct{} - -func ContextWithRewriteTTL(ctx context.Context, val uint32) context.Context { - return context.WithValue(ctx, (*rewriteTTLKey)(nil), val) -} - -func RewriteTTLFromContext(ctx context.Context) (uint32, bool) { - val := ctx.Value((*rewriteTTLKey)(nil)) - if val == nil { - return 0, false - } - return val.(uint32), true -} - -type transportKey struct{} - -func contextWithTransportName(ctx context.Context, transportName string) context.Context { - return context.WithValue(ctx, transportKey{}, transportName) -} - -func transportNameFromContext(ctx context.Context) (string, bool) { - value, loaded := ctx.Value(transportKey{}).(string) - return value, loaded -} - -type clientSubnetKey struct{} - -func ContextWithClientSubnet(ctx context.Context, clientSubnet netip.Prefix) context.Context { - return context.WithValue(ctx, clientSubnetKey{}, clientSubnet) -} - -func ClientSubnetFromContext(ctx context.Context) (netip.Prefix, bool) { - clientSubnet, ok := ctx.Value(clientSubnetKey{}).(netip.Prefix) - return clientSubnet, ok -}