Skip to content

Commit

Permalink
Move ctx options to struct
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Oct 21, 2024
1 parent af17d0a commit a59e0fb
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 119 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
/.idea/
/vendor/
.DS_Store
118 changes: 57 additions & 61 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,11 @@ func (c *Client) Start() {
}
}

func (c *Client) Exchange(ctx context.Context, transport Transport, message *dns.Msg, strategy DomainStrategy) (*dns.Msg, error) {
return c.ExchangeWithResponseCheck(ctx, transport, message, strategy, nil)
func (c *Client) Exchange(ctx context.Context, transport Transport, message *dns.Msg, options QueryOptions) (*dns.Msg, error) {
return c.ExchangeWithResponseCheck(ctx, transport, message, options, nil)
}

func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transport, message *dns.Msg, strategy DomainStrategy, responseChecker func(response *dns.Msg) bool) (*dns.Msg, error) {
func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transport, message *dns.Msg, options QueryOptions, responseChecker func(response *dns.Msg) bool) (*dns.Msg, error) {
if len(message.Question) == 0 {
if c.logger != nil {
c.logger.WarnContext(ctx, "bad question size: ", len(message.Question))
Expand All @@ -109,15 +109,14 @@ func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transp
return &responseMessage, nil
}
question := message.Question[0]
clientSubnet, clientSubnetLoaded := ClientSubnetFromContext(ctx)
if clientSubnetLoaded {
message = SetClientSubnet(message, clientSubnet, true)
if options.ClientSubnet.IsValid() {
message = SetClientSubnet(message, options.ClientSubnet, true)
}
isSimpleRequest := len(message.Question) == 1 &&
len(message.Ns) == 0 &&
len(message.Extra) == 0 &&
!clientSubnetLoaded
disableCache := !isSimpleRequest || c.disableCache || DisableCacheFromContext(ctx)
!options.ClientSubnet.IsValid()
disableCache := !isSimpleRequest || c.disableCache || options.DisableCache
if !disableCache {
response, ttl := c.loadResponse(question, transport)
if response != nil {
Expand All @@ -126,7 +125,7 @@ func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transp
return response, nil
}
}
if question.Qtype == dns.TypeA && strategy == DomainStrategyUseIPv6 || question.Qtype == dns.TypeAAAA && strategy == DomainStrategyUseIPv4 {
if question.Qtype == dns.TypeA && options.Strategy == DomainStrategyUseIPv6 || question.Qtype == dns.TypeAAAA && options.Strategy == DomainStrategyUseIPv4 {
responseMessage := dns.Msg{
MsgHdr: dns.MsgHdr{
Id: message.Id,
Expand All @@ -142,7 +141,7 @@ func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transp
}
if !transport.Raw() {
if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA {
return c.exchangeToLookup(ctx, transport, message, question)
return c.exchangeToLookup(ctx, transport, message, question, options)
}
return nil, ErrNoRawSupport
}
Expand Down Expand Up @@ -171,15 +170,15 @@ func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transp
return response, ErrResponseRejected
}
if question.Qtype == dns.TypeHTTPS {
if strategy == DomainStrategyUseIPv4 || strategy == DomainStrategyUseIPv6 {
if options.Strategy == DomainStrategyUseIPv4 || options.Strategy == DomainStrategyUseIPv6 {
for _, rr := range response.Answer {
https, isHTTPS := rr.(*dns.HTTPS)
if !isHTTPS {
continue
}
content := https.SVCB
content.Value = common.Filter(content.Value, func(it dns.SVCBKeyValue) bool {
if strategy == DomainStrategyUseIPv4 {
if options.Strategy == DomainStrategyUseIPv4 {
return it.Key() != dns.SVCB_IPV6HINT
} else {
return it.Key() != dns.SVCB_IPV4HINT
Expand All @@ -197,8 +196,8 @@ func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transp
}
}
}
if rewriteTTL, loaded := RewriteTTLFromContext(ctx); loaded {
timeToLive = int(rewriteTTL)
if options.RewriteTTL != nil {
timeToLive = int(*options.RewriteTTL)
}
for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} {
for _, record := range recordList {
Expand All @@ -213,34 +212,34 @@ func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transp
return response, err
}

func (c *Client) Lookup(ctx context.Context, transport Transport, domain string, strategy DomainStrategy) ([]netip.Addr, error) {
return c.LookupWithResponseCheck(ctx, transport, domain, strategy, nil)
func (c *Client) Lookup(ctx context.Context, transport Transport, domain string, options QueryOptions) ([]netip.Addr, error) {
return c.LookupWithResponseCheck(ctx, transport, domain, options, nil)
}

func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transport, domain string, strategy DomainStrategy, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) {
func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transport, domain string, options QueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) {
if dns.IsFqdn(domain) {
domain = domain[:len(domain)-1]
}
dnsName := dns.Fqdn(domain)
if transport.Raw() {
if strategy == DomainStrategyUseIPv4 {
return c.lookupToExchange(ctx, transport, dnsName, dns.TypeA, strategy, responseChecker)
} else if strategy == DomainStrategyUseIPv6 {
return c.lookupToExchange(ctx, transport, dnsName, dns.TypeAAAA, strategy, responseChecker)
if options.Strategy == DomainStrategyUseIPv4 {
return c.lookupToExchange(ctx, transport, dnsName, dns.TypeA, options, responseChecker)
} else if options.Strategy == DomainStrategyUseIPv6 {
return c.lookupToExchange(ctx, transport, dnsName, dns.TypeAAAA, options, responseChecker)
}
var response4 []netip.Addr
var response6 []netip.Addr
var group task.Group
group.Append("exchange4", func(ctx context.Context) error {
response, err := c.lookupToExchange(ctx, transport, dnsName, dns.TypeA, strategy, responseChecker)
response, err := c.lookupToExchange(ctx, transport, dnsName, dns.TypeA, options, responseChecker)
if err != nil {
return err
}
response4 = response
return nil
})
group.Append("exchange6", func(ctx context.Context) error {
response, err := c.lookupToExchange(ctx, transport, dnsName, dns.TypeAAAA, strategy, responseChecker)
response, err := c.lookupToExchange(ctx, transport, dnsName, dns.TypeAAAA, options, responseChecker)
if err != nil {
return err
}
Expand All @@ -251,11 +250,11 @@ func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transpor
if len(response4) == 0 && len(response6) == 0 {
return nil, err
}
return sortAddresses(response4, response6, strategy), nil
return sortAddresses(response4, response6, options.Strategy), nil
}
disableCache := c.disableCache || DisableCacheFromContext(ctx)
disableCache := c.disableCache || options.DisableCache
if !disableCache {
if strategy == DomainStrategyUseIPv4 {
if options.Strategy == DomainStrategyUseIPv4 {
response, err := c.questionCache(dns.Question{
Name: dnsName,
Qtype: dns.TypeA,
Expand All @@ -264,7 +263,7 @@ func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transpor
if err != ErrNotCached {
return response, err
}
} else if strategy == DomainStrategyUseIPv6 {
} else if options.Strategy == DomainStrategyUseIPv6 {
response, err := c.questionCache(dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Expand All @@ -285,16 +284,16 @@ func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transpor
Qclass: dns.ClassINET,
}, transport)
if len(response4) > 0 || len(response6) > 0 {
return sortAddresses(response4, response6, strategy), nil
return sortAddresses(response4, response6, options.Strategy), nil
}
}
}
if responseChecker != nil && c.rdrc != nil {
var rejected bool
if strategy != DomainStrategyUseIPv6 {
if options.Strategy != DomainStrategyUseIPv6 {
rejected = c.rdrc.LoadRDRC(transport.Name(), dnsName, dns.TypeA)
}
if !rejected && strategy != DomainStrategyUseIPv4 {
if !rejected && options.Strategy != DomainStrategyUseIPv4 {
rejected = c.rdrc.LoadRDRC(transport.Name(), dnsName, dns.TypeAAAA)
}
if rejected {
Expand All @@ -303,7 +302,7 @@ func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transpor
}
ctx, cancel := context.WithTimeout(ctx, c.timeout)
var rCode int
response, err := transport.Lookup(ctx, domain, strategy)
response, err := transport.Lookup(ctx, domain, options.Strategy)
cancel()
if err != nil {
return nil, wrapError(err)
Expand All @@ -329,12 +328,12 @@ func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transpor
}
if !disableCache {
var timeToLive uint32
if rewriteTTL, loaded := RewriteTTLFromContext(ctx); loaded {
timeToLive = rewriteTTL
if options.RewriteTTL != nil {
timeToLive = *options.RewriteTTL
} else {
timeToLive = DefaultTTL
}
if strategy != DomainStrategyUseIPv6 {
if options.Strategy != DomainStrategyUseIPv6 {
question4 := dns.Question{
Name: dnsName,
Qtype: dns.TypeA,
Expand Down Expand Up @@ -362,7 +361,7 @@ func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transpor
}
c.storeCache(transport, question4, message4, int(timeToLive))
}
if strategy != DomainStrategyUseIPv4 {
if options.Strategy != DomainStrategyUseIPv4 {
question6 := dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Expand Down Expand Up @@ -404,11 +403,7 @@ func (c *Client) ClearCache() {
}

func (c *Client) LookupCache(ctx context.Context, domain string, strategy DomainStrategy) ([]netip.Addr, bool) {
if c.independentCache {
return nil, false
}
disableCache := c.disableCache || DisableCacheFromContext(ctx)
if disableCache {
if c.disableCache || c.independentCache {
return nil, false
}
if dns.IsFqdn(domain) {
Expand Down Expand Up @@ -452,19 +447,10 @@ func (c *Client) LookupCache(ctx context.Context, domain string, strategy Domain
}

func (c *Client) ExchangeCache(ctx context.Context, message *dns.Msg) (*dns.Msg, bool) {
if c.independentCache || len(message.Question) != 1 {
if c.disableCache || c.independentCache || len(message.Question) != 1 {
return nil, false
}
question := message.Question[0]
_, clientSubnetLoaded := transportNameFromContext(ctx)
isSimpleRequest := len(message.Question) == 1 &&
len(message.Ns) == 0 &&
len(message.Extra) == 0 &&
!clientSubnetLoaded
disableCache := !isSimpleRequest || c.disableCache || DisableCacheFromContext(ctx)
if disableCache {
return nil, false
}
response, ttl := c.loadResponse(question, nil)
if response == nil {
return nil, false
Expand Down Expand Up @@ -508,15 +494,14 @@ func (c *Client) storeCache(transport Transport, question dns.Question, message
}
}

func (c *Client) exchangeToLookup(ctx context.Context, transport Transport, message *dns.Msg, question dns.Question) (*dns.Msg, error) {
func (c *Client) exchangeToLookup(ctx context.Context, transport Transport, message *dns.Msg, question dns.Question, options QueryOptions) (*dns.Msg, error) {
domain := question.Name
var strategy DomainStrategy
if question.Qtype == dns.TypeA {
strategy = DomainStrategyUseIPv4
options.Strategy = DomainStrategyUseIPv4
} else {
strategy = DomainStrategyUseIPv6
options.Strategy = DomainStrategyUseIPv6
}
result, err := c.Lookup(ctx, transport, domain, strategy)
result, err := c.Lookup(ctx, transport, domain, options)
if err != nil {
return nil, wrapError(err)
}
Expand All @@ -529,8 +514,8 @@ func (c *Client) exchangeToLookup(ctx context.Context, transport Transport, mess
Question: message.Question,
}
var timeToLive uint32
if rewriteTTL, loaded := RewriteTTLFromContext(ctx); loaded {
timeToLive = rewriteTTL
if options.RewriteTTL != nil {
timeToLive = *options.RewriteTTL
} else {
timeToLive = DefaultTTL
}
Expand Down Expand Up @@ -563,13 +548,13 @@ func (c *Client) exchangeToLookup(ctx context.Context, transport Transport, mess
return &response, nil
}

func (c *Client) lookupToExchange(ctx context.Context, transport Transport, name string, qType uint16, strategy DomainStrategy, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) {
func (c *Client) lookupToExchange(ctx context.Context, transport Transport, name string, qType uint16, options QueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) {
question := dns.Question{
Name: name,
Qtype: qType,
Qclass: dns.ClassINET,
}
disableCache := c.disableCache || DisableCacheFromContext(ctx)
disableCache := c.disableCache || options.DisableCache
if !disableCache {
cachedAddresses, err := c.questionCache(question, transport)
if err != ErrNotCached {
Expand All @@ -587,15 +572,15 @@ func (c *Client) lookupToExchange(ctx context.Context, transport Transport, name
err error
)
if responseChecker != nil {
response, err = c.ExchangeWithResponseCheck(ctx, transport, &message, strategy, func(response *dns.Msg) bool {
response, err = c.ExchangeWithResponseCheck(ctx, transport, &message, options, func(response *dns.Msg) bool {
addresses, addrErr := MessageToAddresses(response)
if addrErr != nil {
return false
}
return responseChecker(addresses)
})
} else {
response, err = c.Exchange(ctx, transport, &message, strategy)
response, err = c.Exchange(ctx, transport, &message, options)
}
if err != nil {
return nil, err
Expand Down Expand Up @@ -718,3 +703,14 @@ func wrapError(err error) error {
}
return err
}

type transportKey struct{}

func contextWithTransportName(ctx context.Context, transportName string) context.Context {
return context.WithValue(ctx, transportKey{}, transportName)
}

func transportNameFromContext(ctx context.Context) (string, bool) {
value, loaded := ctx.Value(transportKey{}).(string)
return value, loaded
}
10 changes: 10 additions & 0 deletions client_options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package dns

import "net/netip"

type QueryOptions struct {
Strategy DomainStrategy
DisableCache bool
RewriteTTL *uint32
ClientSubnet netip.Prefix
}
8 changes: 6 additions & 2 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ func (d *DialerWrapper) DialContext(ctx context.Context, network string, destina
if destination.IsIP() {
return d.dialer.DialContext(ctx, network, destination)
}
addresses, err := d.client.Lookup(ctx, d.transport, destination.Fqdn, d.strategy)
addresses, err := d.client.Lookup(ctx, d.transport, destination.Fqdn, QueryOptions{
Strategy: d.strategy,
})
if err != nil {
return nil, err
}
Expand All @@ -36,7 +38,9 @@ func (d *DialerWrapper) ListenPacket(ctx context.Context, destination M.Socksadd
if destination.IsIP() {
return d.dialer.ListenPacket(ctx, destination)
}
addresses, err := d.client.Lookup(ctx, d.transport, destination.Fqdn, d.strategy)
addresses, err := d.client.Lookup(ctx, d.transport, destination.Fqdn, QueryOptions{
Strategy: d.strategy,
})
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit a59e0fb

Please sign in to comment.