From 3974c99436697eb1661750173661eda91b8f8d5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 14 Feb 2024 20:20:29 +0800 Subject: [PATCH] Add rejected DNS response cache support --- client.go | 34 +++++++++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index e35a8ef..8654e10 100644 --- a/client.go +++ b/client.go @@ -19,20 +19,28 @@ import ( const DefaultTTL = 600 var ( - ErrNoRawSupport = E.New("no raw query support by current transport") - ErrNotCached = E.New("not cached") - ErrResponseRejected = E.New("response rejected") + ErrNoRawSupport = E.New("no raw query support by current transport") + ErrNotCached = E.New("not cached") + ErrResponseRejected = E.New("response rejected") + ErrResponseRejectedCached = E.Extend(ErrResponseRejected, "cached") ) type Client struct { disableCache bool disableExpire bool independentCache bool + rdrc RDRCStore logger logger.ContextLogger cache *cache.LruCache[dns.Question, *dns.Msg] transportCache *cache.LruCache[transportCacheKey, *dns.Msg] } +type RDRCStore interface { + LoadRDRC(transportName string, qName string) (rejected bool) + SaveRDRC(transportName string, qName string) error + SaveRDRCAsync(transportName string, qName string, logger logger.Logger) +} + type transportCacheKey struct { dns.Question transportName string @@ -42,6 +50,7 @@ type ClientOptions struct { DisableCache bool DisableExpire bool IndependentCache bool + RDRC RDRCStore Logger logger.ContextLogger } @@ -50,6 +59,7 @@ func NewClient(options ClientOptions) *Client { disableCache: options.DisableCache, disableExpire: options.DisableExpire, independentCache: options.IndependentCache, + rdrc: options.RDRC, logger: options.Logger, } if !client.disableCache { @@ -121,11 +131,20 @@ func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transp if loaded { SetClientSubnet(message, clientSubnet, true) } + if responseChecker != nil && c.rdrc != nil { + rejected := c.rdrc.LoadRDRC(transport.Name(), question.Name) + if rejected { + return nil, ErrResponseRejectedCached + } + } response, err := transport.Exchange(ctx, message) if err != nil { return nil, err } if responseChecker != nil && !responseChecker(response) { + if c.rdrc != nil { + c.rdrc.SaveRDRCAsync(transport.Name(), question.Name, c.logger) + } return response, ErrResponseRejected } var timeToLive int @@ -238,12 +257,21 @@ func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transpor } } } + if responseChecker != nil && c.rdrc != nil { + rejected := c.rdrc.LoadRDRC(transport.Name(), dnsName) + if rejected { + return nil, ErrResponseRejectedCached + } + } var rCode int response, err := transport.Lookup(ctx, domain, strategy) if err != nil { return nil, wrapError(err) } if responseChecker != nil && !responseChecker(response) { + if c.rdrc != nil { + c.rdrc.SaveRDRCAsync(transport.Name(), dnsName, c.logger) + } return response, ErrResponseRejected } header := dns.MsgHdr{