Skip to content

Commit

Permalink
refactor: Minimise duplicated client TSIG code (#117)
Browse files Browse the repository at this point in the history
  • Loading branch information
bodgit authored Oct 1, 2023
1 parent 5e5ddfb commit 8b85393
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 182 deletions.
43 changes: 2 additions & 41 deletions gss/apcera.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,24 +71,7 @@ func (c *Client) Close() error {
return multierror.Append(c.close(), c.lib.Unload())
}

// Generate generates the TSIG MAC based on the established context.
// It is called with the bytes of the DNS message, and the partial TSIG
// record containing the algorithm and name which is the negotiated TKEY
// for this context.
// It returns the bytes for the TSIG MAC and any error that occurred.
func (c *Client) Generate(msg []byte, t *dns.TSIG) (b []byte, err error) {
if dns.CanonicalName(t.Algorithm) != tsig.GSS {
return nil, dns.ErrKeyAlg
}

c.m.RLock()
defer c.m.RUnlock()

ctx, ok := c.ctx[t.Hdr.Name]
if !ok {
return nil, dns.ErrSecret
}

func (c *Client) generate(ctx *gssapi.CtxId, msg []byte) ([]byte, error) {
message, err := c.lib.MakeBufferBytes(msg)
if err != nil {
return nil, err
Expand All @@ -110,24 +93,7 @@ func (c *Client) Generate(msg []byte, t *dns.TSIG) (b []byte, err error) {
return token.Bytes(), nil
}

// Verify verifies the TSIG MAC based on the established context.
// It is called with the bytes of the DNS message, and the TSIG record
// containing the algorithm, MAC, and name which is the negotiated TKEY
// for this context.
// It returns any error that occurred.
func (c *Client) Verify(stripped []byte, t *dns.TSIG) (err error) {
if dns.CanonicalName(t.Algorithm) != tsig.GSS {
return dns.ErrKeyAlg
}

c.m.RLock()
defer c.m.RUnlock()

ctx, ok := c.ctx[t.Hdr.Name]
if !ok {
return dns.ErrSecret
}

func (c *Client) verify(ctx *gssapi.CtxId, stripped, mac []byte) error {
// Turn the TSIG-stripped message bytes into a *gssapi.Buffer
message, err := c.lib.MakeBufferBytes(stripped)
if err != nil {
Expand All @@ -138,11 +104,6 @@ func (c *Client) Verify(stripped []byte, t *dns.TSIG) (err error) {
err = multierror.Append(err, message.Release()).ErrorOrNil()
}()

mac, err := hex.DecodeString(t.MAC)
if err != nil {
return err
}

// Turn the TSIG MAC bytes into a *gssapi.Buffer
token, err := c.lib.MakeBufferBytes(mac)
if err != nil {
Expand Down
106 changes: 106 additions & 0 deletions gss/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package gss

import (
"encoding/hex"

"github.com/bodgit/tsig"
"github.com/go-logr/logr"
multierror "github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
)

var _ dns.TsigProvider = new(Client)

// Generate generates the TSIG MAC based on the established context.
// It is called with the bytes of the DNS message, and the partial TSIG
// record containing the algorithm and name which is the negotiated TKEY
// for this context.
// It returns the bytes for the TSIG MAC and any error that occurred.
func (c *Client) Generate(msg []byte, t *dns.TSIG) ([]byte, error) {
if dns.CanonicalName(t.Algorithm) != tsig.GSS {
return nil, dns.ErrKeyAlg
}

c.m.RLock()
defer c.m.RUnlock()

ctx, ok := c.ctx[t.Hdr.Name]
if !ok {
return nil, dns.ErrSecret
}

return c.generate(ctx, msg)
}

// Verify verifies the TSIG MAC based on the established context.
// It is called with the bytes of the DNS message, and the TSIG record
// containing the algorithm, MAC, and name which is the negotiated TKEY
// for this context.
// It returns any error that occurred.
func (c *Client) Verify(stripped []byte, t *dns.TSIG) error {
if dns.CanonicalName(t.Algorithm) != tsig.GSS {
return dns.ErrKeyAlg
}

c.m.RLock()
defer c.m.RUnlock()

ctx, ok := c.ctx[t.Hdr.Name]
if !ok {
return dns.ErrSecret
}

mac, err := hex.DecodeString(t.MAC)
if err != nil {
return err
}

return c.verify(ctx, stripped, mac)
}

func (c *Client) close() error {
c.m.RLock()

keys := make([]string, 0, len(c.ctx))
for k := range c.ctx {
keys = append(keys, k)
}

c.m.RUnlock()

var errs error
for _, k := range keys {
errs = multierror.Append(errs, c.DeleteContext(k))
}

return errs
}

func (c *Client) setOption(options ...func(*Client) error) error {
for _, option := range options {
if err := option(c); err != nil {
return err
}
}

return nil
}

// SetConfig sets the Kerberos configuration used by c.
func (c *Client) SetConfig(config string) error {
return c.setOption(WithConfig(config))
}

// WithLogger sets the logger used.
func WithLogger(logger logr.Logger) func(*Client) error {
return func(c *Client) error {
c.logger = logger.WithName("client")

return nil
}
}

// SetLogger sets the logger used by c.
func (c *Client) SetLogger(logger logr.Logger) error {
return c.setOption(WithLogger(logger))
}
43 changes: 2 additions & 41 deletions gss/gokrb5.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,50 +67,11 @@ func (c *Client) Close() error {
return c.close()
}

// Generate generates the TSIG MAC based on the established context.
// It is called with the bytes of the DNS message, and the partial TSIG
// record containing the algorithm and name which is the negotiated TKEY
// for this context.
// It returns the bytes for the TSIG MAC and any error that occurred.
func (c *Client) Generate(msg []byte, t *dns.TSIG) ([]byte, error) {
if dns.CanonicalName(t.Algorithm) != tsig.GSS {
return nil, dns.ErrKeyAlg
}

c.m.RLock()
defer c.m.RUnlock()

ctx, ok := c.ctx[t.Hdr.Name]
if !ok {
return nil, dns.ErrSecret
}

func (c *Client) generate(ctx *wrapper.Initiator, msg []byte) ([]byte, error) {
return ctx.MakeSignature(msg)
}

// Verify verifies the TSIG MAC based on the established context.
// It is called with the bytes of the DNS message, and the TSIG record
// containing the algorithm, MAC, and name which is the negotiated TKEY
// for this context.
// It returns any error that occurred.
func (c *Client) Verify(stripped []byte, t *dns.TSIG) error {
if dns.CanonicalName(t.Algorithm) != tsig.GSS {
return dns.ErrKeyAlg
}

c.m.RLock()
defer c.m.RUnlock()

ctx, ok := c.ctx[t.Hdr.Name]
if !ok {
return dns.ErrSecret
}

mac, err := hex.DecodeString(t.MAC)
if err != nil {
return err
}

func (c *Client) verify(ctx *wrapper.Initiator, stripped, mac []byte) error {
return ctx.VerifySignature(stripped, mac)
}

Expand Down
49 changes: 0 additions & 49 deletions gss/gss.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,6 @@ import (
"math/big"

"github.com/bodgit/tsig"
"github.com/go-logr/logr"
multierror "github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
)

Expand Down Expand Up @@ -129,50 +127,3 @@ func generateSPN(host string) string {

return fmt.Sprintf("DNS/%s", host)
}

func (c *Client) close() error {
c.m.RLock()

keys := make([]string, 0, len(c.ctx))
for k := range c.ctx {
keys = append(keys, k)
}

c.m.RUnlock()

var errs error
for _, k := range keys {
errs = multierror.Append(errs, c.DeleteContext(k))
}

return errs
}

func (c *Client) setOption(options ...func(*Client) error) error {
for _, option := range options {
if err := option(c); err != nil {
return err
}
}

return nil
}

// SetConfig sets the Kerberos configuration used by c.
func (c *Client) SetConfig(config string) error {
return c.setOption(WithConfig(config))
}

// WithLogger sets the logger used.
func WithLogger(logger logr.Logger) func(*Client) error {
return func(c *Client) error {
c.logger = logger.WithName("client")

return nil
}
}

// SetLogger sets the logger used by c.
func (c *Client) SetLogger(logger logr.Logger) error {
return c.setOption(WithLogger(logger))
}
56 changes: 5 additions & 51 deletions gss/sspi.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,60 +65,14 @@ func (c *Client) Close() error {
return c.close()
}

// Generate generates the TSIG MAC based on the established context.
// It is called with the bytes of the DNS message, and the partial TSIG
// record containing the algorithm and name which is the negotiated TKEY
// for this context.
// It returns the bytes for the TSIG MAC and any error that occurred.
func (c *Client) Generate(msg []byte, t *dns.TSIG) ([]byte, error) {
if dns.CanonicalName(t.Algorithm) != tsig.GSS {
return nil, dns.ErrKeyAlg
}

c.m.RLock()
defer c.m.RUnlock()

ctx, ok := c.ctx[t.Hdr.Name]
if !ok {
return nil, dns.ErrSecret
}

token, err := ctx.MakeSignature(msg, 0, 0)
if err != nil {
return nil, err
}

return token, nil
func (c *Client) generate(ctx *negotiate.ClientContext, msg []byte) ([]byte, error) {
return ctx.MakeSignature(msg, 0, 0)
}

// Verify verifies the TSIG MAC based on the established context.
// It is called with the bytes of the DNS message, and the TSIG record
// containing the algorithm, MAC, and name which is the negotiated TKEY
// for this context.
// It returns any error that occurred.
func (c *Client) Verify(stripped []byte, t *dns.TSIG) error {
if dns.CanonicalName(t.Algorithm) != tsig.GSS {
return dns.ErrKeyAlg
}

c.m.RLock()
defer c.m.RUnlock()

ctx, ok := c.ctx[t.Hdr.Name]
if !ok {
return dns.ErrSecret
}

token, err := hex.DecodeString(t.MAC)
if err != nil {
return err
}

if _, err = ctx.VerifySignature(stripped, token, 0); err != nil {
return err
}
func (c *Client) verify(ctx *negotiate.ClientContext, stripped, mac []byte) error {
_, err := ctx.VerifySignature(stripped, mac, 0)

return nil
return err
}

func (c *Client) negotiateContext(host string, creds *sspi.Credentials) (string, time.Time, error) {
Expand Down

0 comments on commit 8b85393

Please sign in to comment.