From c20ce73a459b0dc0c6f0dce586c705a9c5a86cfc Mon Sep 17 00:00:00 2001 From: Matt Dainty Date: Sun, 1 Oct 2023 23:58:38 +0100 Subject: [PATCH] refactor: Minimise duplicated client TSIG code --- gss/apcera.go | 43 +------------------- gss/client.go | 106 ++++++++++++++++++++++++++++++++++++++++++++++++++ gss/gokrb5.go | 43 +------------------- gss/gss.go | 49 ----------------------- gss/sspi.go | 56 +++----------------------- 5 files changed, 115 insertions(+), 182 deletions(-) create mode 100644 gss/client.go diff --git a/gss/apcera.go b/gss/apcera.go index f9cd6ce..ab10358 100644 --- a/gss/apcera.go +++ b/gss/apcera.go @@ -71,24 +71,7 @@ func (c *Client) Close() error { return multierror.Append(c.close(), c.lib.Unload()) } -// Generate generates the TSIG MAC based on the established context. -// It is called with the bytes of the DNS message, and the partial TSIG -// record containing the algorithm and name which is the negotiated TKEY -// for this context. -// It returns the bytes for the TSIG MAC and any error that occurred. -func (c *Client) Generate(msg []byte, t *dns.TSIG) (b []byte, err error) { - if dns.CanonicalName(t.Algorithm) != tsig.GSS { - return nil, dns.ErrKeyAlg - } - - c.m.RLock() - defer c.m.RUnlock() - - ctx, ok := c.ctx[t.Hdr.Name] - if !ok { - return nil, dns.ErrSecret - } - +func (c *Client) generate(ctx *gssapi.CtxId, msg []byte) ([]byte, error) { message, err := c.lib.MakeBufferBytes(msg) if err != nil { return nil, err @@ -110,24 +93,7 @@ func (c *Client) Generate(msg []byte, t *dns.TSIG) (b []byte, err error) { return token.Bytes(), nil } -// Verify verifies the TSIG MAC based on the established context. -// It is called with the bytes of the DNS message, and the TSIG record -// containing the algorithm, MAC, and name which is the negotiated TKEY -// for this context. -// It returns any error that occurred. -func (c *Client) Verify(stripped []byte, t *dns.TSIG) (err error) { - if dns.CanonicalName(t.Algorithm) != tsig.GSS { - return dns.ErrKeyAlg - } - - c.m.RLock() - defer c.m.RUnlock() - - ctx, ok := c.ctx[t.Hdr.Name] - if !ok { - return dns.ErrSecret - } - +func (c *Client) verify(ctx *gssapi.CtxId, stripped, mac []byte) error { // Turn the TSIG-stripped message bytes into a *gssapi.Buffer message, err := c.lib.MakeBufferBytes(stripped) if err != nil { @@ -138,11 +104,6 @@ func (c *Client) Verify(stripped []byte, t *dns.TSIG) (err error) { err = multierror.Append(err, message.Release()).ErrorOrNil() }() - mac, err := hex.DecodeString(t.MAC) - if err != nil { - return err - } - // Turn the TSIG MAC bytes into a *gssapi.Buffer token, err := c.lib.MakeBufferBytes(mac) if err != nil { diff --git a/gss/client.go b/gss/client.go new file mode 100644 index 0000000..69398a6 --- /dev/null +++ b/gss/client.go @@ -0,0 +1,106 @@ +package gss + +import ( + "encoding/hex" + + "github.com/bodgit/tsig" + "github.com/go-logr/logr" + multierror "github.com/hashicorp/go-multierror" + "github.com/miekg/dns" +) + +var _ dns.TsigProvider = new(Client) + +// Generate generates the TSIG MAC based on the established context. +// It is called with the bytes of the DNS message, and the partial TSIG +// record containing the algorithm and name which is the negotiated TKEY +// for this context. +// It returns the bytes for the TSIG MAC and any error that occurred. +func (c *Client) Generate(msg []byte, t *dns.TSIG) ([]byte, error) { + if dns.CanonicalName(t.Algorithm) != tsig.GSS { + return nil, dns.ErrKeyAlg + } + + c.m.RLock() + defer c.m.RUnlock() + + ctx, ok := c.ctx[t.Hdr.Name] + if !ok { + return nil, dns.ErrSecret + } + + return c.generate(ctx, msg) +} + +// Verify verifies the TSIG MAC based on the established context. +// It is called with the bytes of the DNS message, and the TSIG record +// containing the algorithm, MAC, and name which is the negotiated TKEY +// for this context. +// It returns any error that occurred. +func (c *Client) Verify(stripped []byte, t *dns.TSIG) error { + if dns.CanonicalName(t.Algorithm) != tsig.GSS { + return dns.ErrKeyAlg + } + + c.m.RLock() + defer c.m.RUnlock() + + ctx, ok := c.ctx[t.Hdr.Name] + if !ok { + return dns.ErrSecret + } + + mac, err := hex.DecodeString(t.MAC) + if err != nil { + return err + } + + return c.verify(ctx, stripped, mac) +} + +func (c *Client) close() error { + c.m.RLock() + + keys := make([]string, 0, len(c.ctx)) + for k := range c.ctx { + keys = append(keys, k) + } + + c.m.RUnlock() + + var errs error + for _, k := range keys { + errs = multierror.Append(errs, c.DeleteContext(k)) + } + + return errs +} + +func (c *Client) setOption(options ...func(*Client) error) error { + for _, option := range options { + if err := option(c); err != nil { + return err + } + } + + return nil +} + +// SetConfig sets the Kerberos configuration used by c. +func (c *Client) SetConfig(config string) error { + return c.setOption(WithConfig(config)) +} + +// WithLogger sets the logger used. +func WithLogger(logger logr.Logger) func(*Client) error { + return func(c *Client) error { + c.logger = logger.WithName("client") + + return nil + } +} + +// SetLogger sets the logger used by c. +func (c *Client) SetLogger(logger logr.Logger) error { + return c.setOption(WithLogger(logger)) +} diff --git a/gss/gokrb5.go b/gss/gokrb5.go index eec7bc0..8dfcbad 100644 --- a/gss/gokrb5.go +++ b/gss/gokrb5.go @@ -67,50 +67,11 @@ func (c *Client) Close() error { return c.close() } -// Generate generates the TSIG MAC based on the established context. -// It is called with the bytes of the DNS message, and the partial TSIG -// record containing the algorithm and name which is the negotiated TKEY -// for this context. -// It returns the bytes for the TSIG MAC and any error that occurred. -func (c *Client) Generate(msg []byte, t *dns.TSIG) ([]byte, error) { - if dns.CanonicalName(t.Algorithm) != tsig.GSS { - return nil, dns.ErrKeyAlg - } - - c.m.RLock() - defer c.m.RUnlock() - - ctx, ok := c.ctx[t.Hdr.Name] - if !ok { - return nil, dns.ErrSecret - } - +func (c *Client) generate(ctx *wrapper.Initiator, msg []byte) ([]byte, error) { return ctx.MakeSignature(msg) } -// Verify verifies the TSIG MAC based on the established context. -// It is called with the bytes of the DNS message, and the TSIG record -// containing the algorithm, MAC, and name which is the negotiated TKEY -// for this context. -// It returns any error that occurred. -func (c *Client) Verify(stripped []byte, t *dns.TSIG) error { - if dns.CanonicalName(t.Algorithm) != tsig.GSS { - return dns.ErrKeyAlg - } - - c.m.RLock() - defer c.m.RUnlock() - - ctx, ok := c.ctx[t.Hdr.Name] - if !ok { - return dns.ErrSecret - } - - mac, err := hex.DecodeString(t.MAC) - if err != nil { - return err - } - +func (c *Client) verify(ctx *wrapper.Initiator, stripped, mac []byte) error { return ctx.VerifySignature(stripped, mac) } diff --git a/gss/gss.go b/gss/gss.go index cd7f658..b4f5567 100644 --- a/gss/gss.go +++ b/gss/gss.go @@ -78,8 +78,6 @@ import ( "math/big" "github.com/bodgit/tsig" - "github.com/go-logr/logr" - multierror "github.com/hashicorp/go-multierror" "github.com/miekg/dns" ) @@ -129,50 +127,3 @@ func generateSPN(host string) string { return fmt.Sprintf("DNS/%s", host) } - -func (c *Client) close() error { - c.m.RLock() - - keys := make([]string, 0, len(c.ctx)) - for k := range c.ctx { - keys = append(keys, k) - } - - c.m.RUnlock() - - var errs error - for _, k := range keys { - errs = multierror.Append(errs, c.DeleteContext(k)) - } - - return errs -} - -func (c *Client) setOption(options ...func(*Client) error) error { - for _, option := range options { - if err := option(c); err != nil { - return err - } - } - - return nil -} - -// SetConfig sets the Kerberos configuration used by c. -func (c *Client) SetConfig(config string) error { - return c.setOption(WithConfig(config)) -} - -// WithLogger sets the logger used. -func WithLogger(logger logr.Logger) func(*Client) error { - return func(c *Client) error { - c.logger = logger.WithName("client") - - return nil - } -} - -// SetLogger sets the logger used by c. -func (c *Client) SetLogger(logger logr.Logger) error { - return c.setOption(WithLogger(logger)) -} diff --git a/gss/sspi.go b/gss/sspi.go index e513087..d9d0dc8 100644 --- a/gss/sspi.go +++ b/gss/sspi.go @@ -65,60 +65,14 @@ func (c *Client) Close() error { return c.close() } -// Generate generates the TSIG MAC based on the established context. -// It is called with the bytes of the DNS message, and the partial TSIG -// record containing the algorithm and name which is the negotiated TKEY -// for this context. -// It returns the bytes for the TSIG MAC and any error that occurred. -func (c *Client) Generate(msg []byte, t *dns.TSIG) ([]byte, error) { - if dns.CanonicalName(t.Algorithm) != tsig.GSS { - return nil, dns.ErrKeyAlg - } - - c.m.RLock() - defer c.m.RUnlock() - - ctx, ok := c.ctx[t.Hdr.Name] - if !ok { - return nil, dns.ErrSecret - } - - token, err := ctx.MakeSignature(msg, 0, 0) - if err != nil { - return nil, err - } - - return token, nil +func (c *Client) generate(ctx *negotiate.ClientContext, msg []byte) ([]byte, error) { + return ctx.MakeSignature(msg, 0, 0) } -// Verify verifies the TSIG MAC based on the established context. -// It is called with the bytes of the DNS message, and the TSIG record -// containing the algorithm, MAC, and name which is the negotiated TKEY -// for this context. -// It returns any error that occurred. -func (c *Client) Verify(stripped []byte, t *dns.TSIG) error { - if dns.CanonicalName(t.Algorithm) != tsig.GSS { - return dns.ErrKeyAlg - } - - c.m.RLock() - defer c.m.RUnlock() - - ctx, ok := c.ctx[t.Hdr.Name] - if !ok { - return dns.ErrSecret - } - - token, err := hex.DecodeString(t.MAC) - if err != nil { - return err - } - - if _, err = ctx.VerifySignature(stripped, token, 0); err != nil { - return err - } +func (c *Client) verify(ctx *negotiate.ClientContext, stripped, mac []byte) error { + _, err := ctx.VerifySignature(stripped, mac, 0) - return nil + return err } func (c *Client) negotiateContext(host string, creds *sspi.Credentials) (string, time.Time, error) {