Skip to content

Commit

Permalink
Add support for keyringless auth tokens (#3765)
Browse files Browse the repository at this point in the history
* Add support for keyringless auth tokens
  • Loading branch information
talanknight committed Sep 26, 2023
1 parent 5d77193 commit 1bf6dd5
Show file tree
Hide file tree
Showing 23 changed files with 1,638 additions and 540 deletions.
37 changes: 25 additions & 12 deletions internal/cmd/commands/daemon/addtoken.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,26 +79,39 @@ func (c *AddTokenCommand) Run(args []string) int {

func (c *AddTokenCommand) Add(ctx context.Context) (*api.Error, error) {
const op = "daemon.(AddTokenCommand).Add"
keyringType, tokenName, err := c.DiscoverKeyringTokenInfo()
client, err := c.Client()
if err != nil {
return nil, err
}
at := c.ReadTokenFromKeyring(keyringType, tokenName)
if at == nil {
return nil, errors.New(ctx, errors.Conflict, op, "no auth token available to send to daemon")
}
client, err := c.Client()

keyringType, tokenName, err := c.DiscoverKeyringTokenInfo()
if err != nil {
return nil, err
}

pa := userTokenToAdd{
Keyring: &keyringToken{
pa := upsertTokenRequest{
BoundaryAddr: client.Addr(),
}
switch keyringType {
case "", base.NoneKeyring:
keyringType = base.NoneKeyring
token := client.Token()
if parts := strings.SplitN(token, "_", 4); len(parts) == 3 {
pa.AuthTokenId = strings.Join(parts[:2], "_")
} else {
return nil, errors.New(ctx, errors.InvalidParameter, op, "found auth token is not in the proper format")
}
pa.AuthToken = token
default:
at := c.ReadTokenFromKeyring(keyringType, tokenName)
if at == nil {
return nil, errors.New(ctx, errors.Conflict, op, "no auth token available to send to daemon")
}
pa.Keyring = &keyringToken{
KeyringType: keyringType,
TokenName: tokenName,
},
BoundaryAddr: client.Addr(),
AuthTokenId: at.Id,
}
pa.AuthTokenId = at.Id
}

dotPath, err := DefaultDotDirectory(ctx)
Expand All @@ -109,7 +122,7 @@ func (c *AddTokenCommand) Add(ctx context.Context) (*api.Error, error) {
return addToken(ctx, dotPath, &pa)
}

func addToken(ctx context.Context, daemonPath string, p *userTokenToAdd) (*api.Error, error) {
func addToken(ctx context.Context, daemonPath string, p *upsertTokenRequest) (*api.Error, error) {
const op = "daemon.addToken"
client, err := api.NewClient(nil)
if err != nil {
Expand Down
13 changes: 12 additions & 1 deletion internal/cmd/commands/daemon/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@ package daemon

import (
"context"

"github.com/hashicorp/boundary/internal/daemon/cache"
)

type options struct {
withDebug bool
withDebug bool
withBoundaryTokenReaderFunc cache.BoundaryTokenReaderFn
}

// Option - how options are passed as args
Expand Down Expand Up @@ -36,3 +39,11 @@ func WithDebug(_ context.Context, debug bool) Option {
return nil
}
}

// WithBoundaryTokenReaderFunc provides an option for specifying a BoundaryTokenReaderFn
func WithBoundaryTokenReaderFunc(_ context.Context, fn cache.BoundaryTokenReaderFn) Option {
return func(o *options) error {
o.withBoundaryTokenReaderFunc = fn
return nil
}
}
46 changes: 46 additions & 0 deletions internal/cmd/commands/daemon/options_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

package daemon

import (
"context"
"testing"

"github.com/hashicorp/boundary/api/authtokens"
"github.com/hashicorp/boundary/internal/daemon/cache"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_GetOpts(t *testing.T) {
t.Parallel()
ctx := context.Background()

t.Run("default", func(t *testing.T) {
opts, err := getOpts()
require.NoError(t, err)
testOpts := options{}
assert.Equal(t, opts, testOpts)
})
t.Run("WithDebug", func(t *testing.T) {
opts, err := getOpts(WithDebug(ctx, true))
require.NoError(t, err)
testOpts := getDefaultOptions()
testOpts.withDebug = true
assert.Equal(t, opts, testOpts)
})
t.Run("WithBoundaryTokenReaderFunc", func(t *testing.T) {
var f cache.BoundaryTokenReaderFn = func(ctx context.Context, addr, token string) (*authtokens.AuthToken, error) {
return nil, nil
}
opts, err := getOpts(WithBoundaryTokenReaderFunc(ctx, f))
require.NoError(t, err)

assert.NotNil(t, opts.withBoundaryTokenReaderFunc)
opts.withBoundaryTokenReaderFunc = nil

testOpts := getDefaultOptions()
assert.Equal(t, opts, testOpts)
})
}
58 changes: 54 additions & 4 deletions internal/cmd/commands/daemon/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,16 @@ import (
// Commander is an interface that provides a way to get an apiClient
// and retrieve the keyring and token information used by a command.
type Commander interface {
Client(opt ...base.Option) (*api.Client, error)
ClientProvider
DiscoverKeyringTokenInfo() (string, string, error)
ReadTokenFromKeyring(keyringType, tokenName string) *authtokens.AuthToken
}

// ClientProvider is an interface that provides an api.Client
type ClientProvider interface {
Client(opt ...base.Option) (*api.Client, error)
}

type cacheServer struct {
conf *serverConfig

Expand Down Expand Up @@ -117,22 +122,67 @@ func (s *cacheServer) shutdown(ctx context.Context) error {
return shutdownErr
}

func defaultBoundaryTokenReader(ctx context.Context, cp ClientProvider) (cache.BoundaryTokenReaderFn, error) {
const op = "daemon.defaultBoundaryTokenReader"
switch {
case util.IsNil(cp):
return nil, errors.New(ctx, errors.InvalidParameter, op, "client provider is nil")
}
return func(ctx context.Context, addr, tok string) (*authtokens.AuthToken, error) {
switch {
case addr == "":
return nil, errors.New(ctx, errors.InvalidParameter, op, "address is missing")
case tok == "":
return nil, errors.New(ctx, errors.InvalidParameter, op, "auth token is missing")
}
atIdParts := strings.SplitN(tok, "_", 4)
if len(atIdParts) != 3 {
return nil, errors.New(ctx, errors.InvalidParameter, op, "auth token is malformed")
}
atId := strings.Join(atIdParts[:cache.AuthTokenIdSegmentCount], "_")

c, err := cp.Client()
if err != nil {
return nil, err
}
c.SetAddr(addr)
c.SetToken(tok)
atClient := authtokens.NewClient(c)

at, err := atClient.Read(ctx, atId)
if err != nil {
return nil, errors.Wrap(ctx, err, op)
}
return at.GetItem(), nil
}, nil
}

// start will fire up the refresh goroutine and the caching API http server as a
// daemon. The daemon bits are included so it's easy for CLI cmds to start the
// a cache server
func (s *cacheServer) serve(ctx context.Context, cmd Commander, l net.Listener) error {
func (s *cacheServer) serve(ctx context.Context, cmd Commander, l net.Listener, opt ...Option) error {
const op = "daemon.(cacheServer).start"
switch {
case util.IsNil(ctx):
return errors.New(ctx, errors.InvalidParameter, op, "context is missing")
}

opts, err := getOpts(opt...)
if err != nil {
return errors.Wrap(ctx, err, op)
}
if opts.withBoundaryTokenReaderFunc == nil {
opts.withBoundaryTokenReaderFunc, err = defaultBoundaryTokenReader(ctx, cmd)
if err != nil {
return errors.Wrap(ctx, err, op)
}
}

s.info["Listening address"] = l.Addr().String()
s.infoKeys = append(s.infoKeys, "Listening address")
s.info["Store debug"] = strconv.FormatBool(s.conf.flagStoreDebug)
s.infoKeys = append(s.infoKeys, "Store debug")

var err error
if s.store, s.storeUrl, err = openStore(ctx, s.conf.flagDatabaseUrl, s.conf.flagStoreDebug); err != nil {
return errors.Wrap(ctx, err, op)
}
Expand All @@ -141,7 +191,7 @@ func (s *cacheServer) serve(ctx context.Context, cmd Commander, l net.Listener)

s.printInfo(ctx)

repo, err := cache.NewRepository(ctx, s.store, cmd.ReadTokenFromKeyring)
repo, err := cache.NewRepository(ctx, s.store, &sync.Map{}, cmd.ReadTokenFromKeyring, opts.withBoundaryTokenReaderFunc)
if err != nil {
return errors.Wrap(ctx, err, op)
}
Expand Down
97 changes: 97 additions & 0 deletions internal/cmd/commands/daemon/server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

package daemon

import (
"context"
"testing"

"github.com/hashicorp/boundary/api"
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/boundary/internal/daemon/controller"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// Note: the name of this test must remain short because the temp dir created
// includes the name of the test and there is a 108 character limit in allowed
// unix socket path names.
func TestDefaultBoundaryTokenReader(t *testing.T) {
ctx := context.Background()

t.Run("nil client provider", func(t *testing.T) {
resFn, err := defaultBoundaryTokenReader(ctx, nil)
assert.Error(t, err)
assert.ErrorContains(t, err, "client provider is nil")
assert.Nil(t, resFn)
})

tc := controller.NewTestController(t, nil)
cp := fakeClientProvider{tc}

cases := []struct {
name string
address string
token string
errContains string
}{
{
name: "success",
address: tc.ApiAddrs()[0],
token: tc.Token().Token,
errContains: "",
},
{
name: "empty address",
address: "",
token: "at_123_testtoken",
errContains: "address is missing",
},
{
name: "empty token",
address: tc.ApiAddrs()[0],
token: "",
errContains: "auth token is missing",
},
{
name: "malformed token to many sections",
address: tc.ApiAddrs()[0],
token: "at_123_ignoredtoken_tomanysections",
errContains: "auth token is malformed",
},
{
name: "malformed token to few sections",
address: tc.ApiAddrs()[0],
token: "at_123",
errContains: "auth token is malformed",
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
retFn, err := defaultBoundaryTokenReader(ctx, cp)
require.NoError(t, err)
require.NotNil(t, retFn)

at, err := retFn(ctx, tc.address, tc.token)
switch tc.errContains {
case "":
assert.NoError(t, err)
assert.NotNil(t, at)
default:
assert.Error(t, err)
assert.ErrorContains(t, err, tc.errContains)
assert.Nil(t, at)
}
})
}
}

type fakeClientProvider struct {
*controller.TestController
}

func (fcp fakeClientProvider) Client(opt ...base.Option) (*api.Client, error) {
return fcp.TestController.Client(), nil
}
22 changes: 15 additions & 7 deletions internal/cmd/commands/daemon/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ package daemon

import (
"context"
stdErrors "errors"
"io"
"strings"
"sync"
"testing"

"github.com/hashicorp/boundary/api/authtokens"
"github.com/hashicorp/boundary/api/sessions"
"github.com/hashicorp/boundary/api/targets"
"github.com/hashicorp/boundary/internal/daemon/cache"
Expand Down Expand Up @@ -54,7 +56,7 @@ func (s *TestServer) BaseSocketDir() string {

// Serve runs the cache server. This is a blocking call and returns when the
// server is shutdown or stops for any other reason.
func (s *TestServer) Serve(t *testing.T) error {
func (s *TestServer) Serve(t *testing.T, opt ...Option) error {
t.Helper()
ctx := context.Background()

Expand All @@ -64,28 +66,34 @@ func (s *TestServer) Serve(t *testing.T) error {
t.Cleanup(func() {
s.shutdown(ctx)
})
return s.cacheServer.serve(ctx, s.cmd, l)
return s.cacheServer.serve(ctx, s.cmd, l, opt...)
}

// AddResources adds targets to the cache for the provided address, token name,
// and keyring type. They token info must already be known to the server.
func (s *TestServer) AddResources(t *testing.T, p *cache.AuthToken, tars []*targets.Target, sess []*sessions.Session) {
func (s *TestServer) AddResources(t *testing.T, p *authtokens.AuthToken, tars []*targets.Target, sess []*sessions.Session) {
t.Helper()
ctx := context.Background()
r, err := cache.NewRepository(ctx, s.cacheServer.store, s.cmd.ReadTokenFromKeyring)
r, err := cache.NewRepository(ctx, s.cacheServer.store, &sync.Map{}, s.cmd.ReadTokenFromKeyring, unimplementedAuthTokenReader)
require.NoError(t, err)

tarFn := func(ctx context.Context, _, tok string) ([]*targets.Target, error) {
if !strings.HasPrefix(tok, p.Id) {
if tok != p.Token {
return nil, nil
}
return tars, nil
}
sessFn := func(ctx context.Context, _, tok string) ([]*sessions.Session, error) {
if !strings.HasPrefix(tok, p.Id) {
if tok != p.Token {
return nil, nil
}
return sess, nil
}
require.NoError(t, r.Refresh(ctx, cache.WithTargetRetrievalFunc(tarFn), cache.WithSessionRetrievalFunc(sessFn)))
}

// unimplementedAuthTokenReader is an unimplemented function for reading auth
// tokens from a provided boundary address.
func unimplementedAuthTokenReader(ctx context.Context, addr string, authToken string) (*authtokens.AuthToken, error) {
return nil, stdErrors.New("unimplemented")
}
Loading

0 comments on commit 1bf6dd5

Please sign in to comment.