From be119ebd615008efeb3d1341aecb5f1e3e723760 Mon Sep 17 00:00:00 2001 From: Todd Date: Mon, 9 Oct 2023 16:29:53 -0700 Subject: [PATCH] Refactor for self contained domain logic --- .../{daemon/cache/store.go => cache/db/db.go} | 27 +++-- .../cache/store_cgo.go => cache/db/db_cgo.go} | 2 +- .../store_nocgo.go => cache/db/db_nocgo.go} | 2 +- .../commands/daemon => cache/db}/options.go | 27 ++--- .../daemon => cache/db}/options_test.go | 30 +++--- .../{daemon/cache => cache/db}/schema.sql | 0 internal/cache/options.go | 59 +++++++++++ internal/cache/options_test.go | 59 +++++++++++ .../refresh.go} | 44 ++++---- .../refresh_test.go} | 56 +++++----- internal/{daemon => }/cache/repository.go | 8 +- .../{daemon => }/cache/repository_sessions.go | 0 .../cache/repository_sessions_test.go | 9 +- .../{daemon => }/cache/repository_targets.go | 0 .../cache/repository_targets_test.go | 9 +- .../{daemon => }/cache/repository_test.go | 3 +- .../{daemon => }/cache/repository_token.go | 0 .../cache/repository_token_test.go | 31 +++--- internal/{daemon => }/cache/search.go | 0 internal/{daemon => }/cache/search_test.go | 9 +- internal/{daemon => }/cache/store_test.go | 30 +++--- internal/cmd/commands/daemon/addtoken.go | 11 +- .../cmd/commands/daemon/command_wrapper.go | 16 +-- internal/cmd/commands/daemon/start.go | 34 +++--- internal/cmd/commands/daemon/token_test.go | 70 ++++++------ internal/cmd/commands/search/search.go | 7 +- internal/cmd/commands/search/search_test.go | 28 ++--- .../daemon => daemon/cache}/handlers.go | 2 +- .../daemon => daemon/cache}/listener.go | 6 +- .../cache}/listener_nonwindows_test.go | 4 +- .../daemon => daemon/cache}/listener_test.go | 4 +- internal/daemon/cache/options.go | 48 ++------- internal/daemon/cache/options_test.go | 48 ++------- internal/daemon/cache/pidfile_nonwindows.go | 73 +++++++++++++ internal/daemon/cache/pidfile_test.go | 61 +++++++++++ internal/daemon/cache/pidfile_windows.go | 100 ++++++++++++++++++ internal/daemon/cache/query.go | 4 - .../daemon => daemon/cache}/search_handler.go | 4 +- .../daemon => daemon/cache}/server.go | 90 +++++++++------- .../daemon => daemon/cache}/server_test.go | 2 +- .../daemon => daemon/cache}/stop_handler.go | 2 +- .../cache}/stop_handler_test.go | 2 +- .../daemon => daemon/cache}/testing.go | 33 +++--- .../daemon => daemon/cache}/ticker.go | 19 ++-- .../daemon => daemon/cache}/ticker_test.go | 27 +++-- .../daemon => daemon/cache}/token_handler.go | 14 +-- .../cache}/version_interceptor.go | 2 +- .../cache}/version_interceptor_test.go | 2 +- 48 files changed, 719 insertions(+), 399 deletions(-) rename internal/{daemon/cache/store.go => cache/db/db.go} (67%) rename internal/{daemon/cache/store_cgo.go => cache/db/db_cgo.go} (90%) rename internal/{daemon/cache/store_nocgo.go => cache/db/db_nocgo.go} (91%) rename internal/{cmd/commands/daemon => cache/db}/options.go (55%) rename internal/{cmd/commands/daemon => cache/db}/options_test.go (51%) rename internal/{daemon/cache => cache/db}/schema.sql (100%) create mode 100644 internal/cache/options.go create mode 100644 internal/cache/options_test.go rename internal/{daemon/cache/repository_refresh.go => cache/refresh.go} (68%) rename internal/{daemon/cache/repository_refresh_test.go => cache/refresh_test.go} (90%) rename internal/{daemon => }/cache/repository.go (91%) rename internal/{daemon => }/cache/repository_sessions.go (100%) rename internal/{daemon => }/cache/repository_sessions_test.go (97%) rename internal/{daemon => }/cache/repository_targets.go (100%) rename internal/{daemon => }/cache/repository_targets_test.go (98%) rename internal/{daemon => }/cache/repository_test.go (96%) rename internal/{daemon => }/cache/repository_token.go (100%) rename internal/{daemon => }/cache/repository_token_test.go (97%) rename internal/{daemon => }/cache/search.go (100%) rename internal/{daemon => }/cache/search_test.go (97%) rename internal/{daemon => }/cache/store_test.go (96%) rename internal/{cmd/commands/daemon => daemon/cache}/handlers.go (97%) rename internal/{cmd/commands/daemon => daemon/cache}/listener.go (92%) rename internal/{cmd/commands/daemon => daemon/cache}/listener_nonwindows_test.go (95%) rename internal/{cmd/commands/daemon => daemon/cache}/listener_test.go (95%) create mode 100644 internal/daemon/cache/pidfile_nonwindows.go create mode 100644 internal/daemon/cache/pidfile_test.go create mode 100644 internal/daemon/cache/pidfile_windows.go delete mode 100644 internal/daemon/cache/query.go rename internal/{cmd/commands/daemon => daemon/cache}/search_handler.go (97%) rename internal/{cmd/commands/daemon => daemon/cache}/server.go (82%) rename internal/{cmd/commands/daemon => daemon/cache}/server_test.go (99%) rename internal/{cmd/commands/daemon => daemon/cache}/stop_handler.go (98%) rename internal/{cmd/commands/daemon => daemon/cache}/stop_handler_test.go (98%) rename internal/{cmd/commands/daemon => daemon/cache}/testing.go (76%) rename internal/{cmd/commands/daemon => daemon/cache}/ticker.go (83%) rename internal/{cmd/commands/daemon => daemon/cache}/ticker_test.go (66%) rename internal/{cmd/commands/daemon => daemon/cache}/token_handler.go (94%) rename internal/{cmd/commands/daemon => daemon/cache}/version_interceptor.go (98%) rename internal/{cmd/commands/daemon => daemon/cache}/version_interceptor_test.go (98%) diff --git a/internal/daemon/cache/store.go b/internal/cache/db/db.go similarity index 67% rename from internal/daemon/cache/store.go rename to internal/cache/db/db.go index 88e64827ffc..2d44c7146c1 100644 --- a/internal/daemon/cache/store.go +++ b/internal/cache/db/db.go @@ -1,7 +1,7 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 -package cache +package db import ( "context" @@ -19,12 +19,10 @@ var cacheSchema string // DefaultStoreUrl uses a temp in-memory sqlite database see: https://www.sqlite.org/inmemorydb.html const DefaultStoreUrl = "file::memory:?_pragma=foreign_keys(1)" -type Store struct { - conn *db.DB -} - -func Open(ctx context.Context, opt ...Option) (*Store, error) { - const op = "cache.Open" +// Open creates a database connection. WithUrl is supported, but by default it +// uses an in memory sqlite table. Sqlite is the only supported dbtype. +func Open(ctx context.Context, opt ...Option) (*db.DB, error) { + const op = "db.Open" opts, err := getOpts(opt...) if err != nil { return nil, errors.Wrap(ctx, err, op) @@ -36,27 +34,26 @@ func Open(ctx context.Context, opt ...Option) (*Store, error) { default: url = DefaultStoreUrl } - underlying, err := db.Open(ctx, db.Sqlite, url) + conn, err := db.Open(ctx, db.Sqlite, url) if err != nil { return nil, errors.Wrap(ctx, err, op) } - s := &Store{conn: underlying} - s.conn.Debug(opts.withDebug) + conn.Debug(opts.withDebug) switch { case opts.withDbType == dbw.Sqlite: - if err := s.createTables(ctx); err != nil { + if err := createTables(ctx, conn); err != nil { errors.Wrap(ctx, err, op) } default: return nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("%q is not a supported cache store type", opts.withDbType)) } - return s, nil + return conn, nil } -func (s *Store) createTables(ctx context.Context) error { - const op = "cache.(Store).createTables" - rw := db.New(s.conn) +func createTables(ctx context.Context, conn *db.DB) error { + const op = "db.createTables" + rw := db.New(conn) if _, err := rw.Exec(ctx, cacheSchema, nil); err != nil { return errors.Wrap(ctx, err, op) } diff --git a/internal/daemon/cache/store_cgo.go b/internal/cache/db/db_cgo.go similarity index 90% rename from internal/daemon/cache/store_cgo.go rename to internal/cache/db/db_cgo.go index eb17f54f9e1..c74c63e3826 100644 --- a/internal/daemon/cache/store_cgo.go +++ b/internal/cache/db/db_cgo.go @@ -4,7 +4,7 @@ //go:build cgo // +build cgo -package cache +package db import ( _ "gorm.io/driver/sqlite" diff --git a/internal/daemon/cache/store_nocgo.go b/internal/cache/db/db_nocgo.go similarity index 91% rename from internal/daemon/cache/store_nocgo.go rename to internal/cache/db/db_nocgo.go index ef2e4dd96c2..4dfdd37d7cf 100644 --- a/internal/daemon/cache/store_nocgo.go +++ b/internal/cache/db/db_nocgo.go @@ -4,7 +4,7 @@ //go:build !cgo // +build !cgo -package cache +package db import ( _ "github.com/glebarez/go-sqlite" diff --git a/internal/cmd/commands/daemon/options.go b/internal/cache/db/options.go similarity index 55% rename from internal/cmd/commands/daemon/options.go rename to internal/cache/db/options.go index 07ba3b6a71d..a44c622925f 100644 --- a/internal/cmd/commands/daemon/options.go +++ b/internal/cache/db/options.go @@ -1,24 +1,25 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 -package daemon +package db import ( - "context" - - "github.com/hashicorp/boundary/internal/daemon/cache" + "github.com/hashicorp/go-dbw" ) type options struct { - withDebug bool - withBoundaryTokenReaderFunc cache.BoundaryTokenReaderFn + withDebug bool + withUrl string + withDbType dbw.DbType } // Option - how options are passed as args type Option func(*options) error func getDefaultOptions() options { - return options{} + return options{ + withDbType: dbw.Sqlite, + } } func getOpts(opt ...Option) (options, error) { @@ -32,18 +33,18 @@ func getOpts(opt ...Option) (options, error) { return opts, nil } -// WithDebug provides an optional debug flag. -func WithDebug(_ context.Context, debug bool) Option { +// WithUrls provides optional url +func WithUrl(url string) Option { return func(o *options) error { - o.withDebug = debug + o.withUrl = url return nil } } -// WithBoundaryTokenReaderFunc provides an option for specifying a BoundaryTokenReaderFn -func WithBoundaryTokenReaderFunc(_ context.Context, fn cache.BoundaryTokenReaderFn) Option { +// WithDebug provides an optional debug flag. +func WithDebug(debug bool) Option { return func(o *options) error { - o.withBoundaryTokenReaderFunc = fn + o.withDebug = debug return nil } } diff --git a/internal/cmd/commands/daemon/options_test.go b/internal/cache/db/options_test.go similarity index 51% rename from internal/cmd/commands/daemon/options_test.go rename to internal/cache/db/options_test.go index 1c307e709e7..f57591ef600 100644 --- a/internal/cmd/commands/daemon/options_test.go +++ b/internal/cache/db/options_test.go @@ -1,46 +1,40 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 -package daemon +package db import ( - "context" "testing" - "github.com/hashicorp/boundary/api/authtokens" - "github.com/hashicorp/boundary/internal/daemon/cache" + "github.com/hashicorp/go-dbw" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func Test_GetOpts(t *testing.T) { t.Parallel() - ctx := context.Background() t.Run("default", func(t *testing.T) { opts, err := getOpts() require.NoError(t, err) - testOpts := options{} + testOpts := options{ + withDbType: dbw.Sqlite, + } assert.Equal(t, opts, testOpts) }) - t.Run("WithDebug", func(t *testing.T) { - opts, err := getOpts(WithDebug(ctx, true)) + t.Run("WithUrl", func(t *testing.T) { + url := "something" + opts, err := getOpts(WithUrl(url)) require.NoError(t, err) testOpts := getDefaultOptions() - testOpts.withDebug = true + testOpts.withUrl = url assert.Equal(t, opts, testOpts) }) - t.Run("WithBoundaryTokenReaderFunc", func(t *testing.T) { - var f cache.BoundaryTokenReaderFn = func(ctx context.Context, addr, token string) (*authtokens.AuthToken, error) { - return nil, nil - } - opts, err := getOpts(WithBoundaryTokenReaderFunc(ctx, f)) + t.Run("WithDebug", func(t *testing.T) { + opts, err := getOpts(WithDebug(true)) require.NoError(t, err) - - assert.NotNil(t, opts.withBoundaryTokenReaderFunc) - opts.withBoundaryTokenReaderFunc = nil - testOpts := getDefaultOptions() + testOpts.withDebug = true assert.Equal(t, opts, testOpts) }) } diff --git a/internal/daemon/cache/schema.sql b/internal/cache/db/schema.sql similarity index 100% rename from internal/daemon/cache/schema.sql rename to internal/cache/db/schema.sql diff --git a/internal/cache/options.go b/internal/cache/options.go new file mode 100644 index 00000000000..7a4e04b785d --- /dev/null +++ b/internal/cache/options.go @@ -0,0 +1,59 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package cache + +import ( + "github.com/hashicorp/go-dbw" +) + +type options struct { + withUpdateLastAccessedTime bool + withDbType dbw.DbType + withTargetRetrievalFunc TargetRetrievalFunc + withSessionRetrievalFunc SessionRetrievalFunc +} + +// Option - how options are passed as args +type Option func(*options) error + +func getDefaultOptions() options { + return options{ + withDbType: dbw.Sqlite, + } +} + +func getOpts(opt ...Option) (options, error) { + opts := getDefaultOptions() + + for _, o := range opt { + if err := o(&opts); err != nil { + return opts, err + } + } + return opts, nil +} + +// WithUpdateLastAccessedTime provides an option for updating the last access time +func WithUpdateLastAccessedTime(b bool) Option { + return func(o *options) error { + o.withUpdateLastAccessedTime = b + return nil + } +} + +// WithTargetRetrievalFunc provides an option for specifying a targetRetrievalFunc +func WithTargetRetrievalFunc(fn TargetRetrievalFunc) Option { + return func(o *options) error { + o.withTargetRetrievalFunc = fn + return nil + } +} + +// WithSessionRetrievalFunc provides an option for specifying a sessionRetrievalFunc +func WithSessionRetrievalFunc(fn SessionRetrievalFunc) Option { + return func(o *options) error { + o.withSessionRetrievalFunc = fn + return nil + } +} diff --git a/internal/cache/options_test.go b/internal/cache/options_test.go new file mode 100644 index 00000000000..9e8e491bae3 --- /dev/null +++ b/internal/cache/options_test.go @@ -0,0 +1,59 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package cache + +import ( + "context" + "testing" + + "github.com/hashicorp/boundary/api/sessions" + "github.com/hashicorp/boundary/api/targets" + "github.com/hashicorp/go-dbw" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_GetOpts(t *testing.T) { + t.Parallel() + + t.Run("default", func(t *testing.T) { + opts, err := getOpts() + require.NoError(t, err) + testOpts := options{ + withDbType: dbw.Sqlite, + } + assert.Equal(t, opts, testOpts) + }) + t.Run("WithUpdateLastAccessedTime", func(t *testing.T) { + opts, err := getOpts(WithUpdateLastAccessedTime(true)) + require.NoError(t, err) + testOpts := getDefaultOptions() + testOpts.withUpdateLastAccessedTime = true + assert.Equal(t, opts, testOpts) + }) + t.Run("WithTargetRetrievalFunc", func(t *testing.T) { + var f TargetRetrievalFunc = func(ctx context.Context, keyringstring, tokenName string) ([]*targets.Target, error) { return nil, nil } + opts, err := getOpts(WithTargetRetrievalFunc(f)) + require.NoError(t, err) + + assert.NotNil(t, opts.withTargetRetrievalFunc) + opts.withTargetRetrievalFunc = nil + + testOpts := getDefaultOptions() + assert.Equal(t, opts, testOpts) + }) + t.Run("WithSessionRetrievalFunc", func(t *testing.T) { + var f SessionRetrievalFunc = func(ctx context.Context, keyringstring, tokenName string) ([]*sessions.Session, error) { + return nil, nil + } + opts, err := getOpts(WithSessionRetrievalFunc(f)) + require.NoError(t, err) + + assert.NotNil(t, opts.withSessionRetrievalFunc) + opts.withSessionRetrievalFunc = nil + + testOpts := getDefaultOptions() + assert.Equal(t, opts, testOpts) + }) +} diff --git a/internal/daemon/cache/repository_refresh.go b/internal/cache/refresh.go similarity index 68% rename from internal/daemon/cache/repository_refresh.go rename to internal/cache/refresh.go index c8447e59719..342f0ab7eac 100644 --- a/internal/daemon/cache/repository_refresh.go +++ b/internal/cache/refresh.go @@ -14,42 +14,50 @@ import ( "github.com/hashicorp/boundary/internal/util" ) +type RefreshService struct { + repo *Repository +} + +func NewRefreshService(ctx context.Context, r *Repository) (*RefreshService, error) { + return &RefreshService{repo: r}, nil +} + // cleanAndPickAuthTokens removes from the cache all auth tokens which are // evicted from the cache or no longer stored in a keyring and returns the // remaining ones. // The returned auth tokens have not been validated against boundary -func (r *Repository) cleanAndPickAuthTokens(ctx context.Context, u *user) (map[AuthToken]string, error) { - const op = "cache.(Repository).cleanAndPickAuthTokens" +func (r *RefreshService) cleanAndPickAuthTokens(ctx context.Context, u *user) (map[AuthToken]string, error) { + const op = "cache.(RefreshService).cleanAndPickAuthTokens" switch { case util.IsNil(u): return nil, errors.New(ctx, errors.InvalidParameter, op, "user is nil") } ret := make(map[AuthToken]string) - tokens, err := r.listTokens(ctx, u) + tokens, err := r.repo.listTokens(ctx, u) if err != nil { return nil, errors.Wrap(ctx, err, op) } for _, t := range tokens { - keyringTokens, err := r.listKeyringTokens(ctx, t) + keyringTokens, err := r.repo.listKeyringTokens(ctx, t) if err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithMsg("for user %v, auth token %q", u, t.Id)) } for _, kt := range keyringTokens { - at := r.tokenKeyringFn(kt.KeyringType, kt.TokenName) + at := r.repo.tokenKeyringFn(kt.KeyringType, kt.TokenName) switch { case at == nil, at.Id != kt.AuthTokenId, at.UserId != t.UserId: // delete the keyring token if the auth token in the keyring // has changed since it was stored in the cache. - if err := r.deleteKeyringToken(ctx, *kt); err != nil { + if err := r.repo.deleteKeyringToken(ctx, *kt); err != nil { return nil, errors.Wrap(ctx, err, op) } case at != nil: - _, err := r.tokenReadFromBoundaryFn(ctx, u.Address, at.Token) + _, err := r.repo.tokenReadFromBoundaryFn(ctx, u.Address, at.Token) var apiErr *api.Error switch { case err != nil && api.ErrUnauthorized.Is(err): - if err := r.deleteKeyringToken(ctx, *kt); err != nil { + if err := r.repo.deleteKeyringToken(ctx, *kt); err != nil { return nil, errors.Wrap(ctx, err, op) } continue @@ -60,13 +68,13 @@ func (r *Repository) cleanAndPickAuthTokens(ctx context.Context, u *user) (map[A ret[*t] = at.Token } } - if atv, ok := r.idToKeyringlessAuthToken.Load(t.Id); ok { + if atv, ok := r.repo.idToKeyringlessAuthToken.Load(t.Id); ok { if at, ok := atv.(*authtokens.AuthToken); ok { - _, err := r.tokenReadFromBoundaryFn(ctx, u.Address, at.Token) + _, err := r.repo.tokenReadFromBoundaryFn(ctx, u.Address, at.Token) var apiErr *api.Error switch { case err != nil && api.ErrUnauthorized.Is(err): - r.idToKeyringlessAuthToken.Delete(t.Id) + r.repo.idToKeyringlessAuthToken.Delete(t.Id) continue case err != nil && !errors.Is(err, apiErr): event.WriteError(ctx, op, err, event.WithInfoMsg("validating in memory stored token against boundary", "auth token id", at.Id)) @@ -85,16 +93,16 @@ func (r *Repository) cleanAndPickAuthTokens(ctx context.Context, u *user) (map[A // the values retrieved there. Refresh accepts the options // WithTarget which overwrites the default function used to retrieve the // targets from a boundary address. -func (r *Repository) Refresh(ctx context.Context, opt ...Option) error { - const op = "cache.(Repository).Refresh" - if err := r.cleanExpiredOrOrphanedAuthTokens(ctx); err != nil { +func (r *RefreshService) Refresh(ctx context.Context, opt ...Option) error { + const op = "cache.(RefreshService).Refresh" + if err := r.repo.cleanExpiredOrOrphanedAuthTokens(ctx); err != nil { return errors.Wrap(ctx, err, op) } - if err := r.syncKeyringlessTokensWithDb(ctx); err != nil { + if err := r.repo.syncKeyringlessTokensWithDb(ctx); err != nil { return errors.Wrap(ctx, err, op) } - us, err := r.listUsers(ctx) + us, err := r.repo.listUsers(ctx) if err != nil { return errors.Wrap(ctx, err, op) } @@ -106,10 +114,10 @@ func (r *Repository) Refresh(ctx context.Context, opt ...Option) error { continue } - if err := r.refreshTargets(ctx, u, tokens, opt...); err != nil { + if err := r.repo.refreshTargets(ctx, u, tokens, opt...); err != nil { retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op)) } - if err := r.refreshSessions(ctx, u, tokens, opt...); err != nil { + if err := r.repo.refreshSessions(ctx, u, tokens, opt...); err != nil { retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op)) } diff --git a/internal/daemon/cache/repository_refresh_test.go b/internal/cache/refresh_test.go similarity index 90% rename from internal/daemon/cache/repository_refresh_test.go rename to internal/cache/refresh_test.go index 47aec60b3b2..a6c4ff96c9f 100644 --- a/internal/daemon/cache/repository_refresh_test.go +++ b/internal/cache/refresh_test.go @@ -6,7 +6,6 @@ package cache import ( "context" "errors" - stdErrors "errors" "fmt" "sync" "testing" @@ -16,6 +15,7 @@ import ( "github.com/hashicorp/boundary/api/authtokens" "github.com/hashicorp/boundary/api/sessions" "github.com/hashicorp/boundary/api/targets" + "github.com/hashicorp/boundary/internal/cache/db" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/exp/maps" @@ -33,7 +33,7 @@ func testStaticResourceRetrievalFunc[T any](ret []T) func(context.Context, strin func TestCleanAndPickTokens(t *testing.T) { ctx := context.Background() - s, err := Open(ctx) + s, err := db.Open(ctx) require.NoError(t, err) boundaryAddr := "address" @@ -84,7 +84,7 @@ func TestCleanAndPickTokens(t *testing.T) { return v, nil } } - return nil, stdErrors.New("not found") + return nil, errors.New("not found") } atMap := make(map[ringToken]*authtokens.AuthToken) @@ -92,9 +92,11 @@ func TestCleanAndPickTokens(t *testing.T) { mapBasedAuthTokenKeyringLookup(atMap), fakeBoundaryLookupFn) require.NoError(t, err) + rs, err := NewRefreshService(ctx, r) + require.NoError(t, err) t.Run("unknown user", func(t *testing.T) { - got, err := r.cleanAndPickAuthTokens(ctx, &user{Id: "unknownuser", Address: "unknown"}) + got, err := rs.cleanAndPickAuthTokens(ctx, &user{Id: "unknownuser", Address: "unknown"}) assert.NoError(t, err) assert.Empty(t, got) }) @@ -109,13 +111,13 @@ func TestCleanAndPickTokens(t *testing.T) { })) require.NoError(t, r.AddRawToken(ctx, boundaryAddr, at1b.Token)) - got, err := r.cleanAndPickAuthTokens(ctx, u1) + got, err := rs.cleanAndPickAuthTokens(ctx, u1) assert.NoError(t, err) assert.ElementsMatch(t, maps.Values(got), []string{at1a.Token, at1b.Token}) // delete the keyringToken from the keyring and see it get removed from the response delete(atMap, key) - got, err = r.cleanAndPickAuthTokens(ctx, u1) + got, err = rs.cleanAndPickAuthTokens(ctx, u1) assert.NoError(t, err) assert.ElementsMatch(t, maps.Values(got), []string{at1b.Token}) }) @@ -124,7 +126,7 @@ func TestCleanAndPickTokens(t *testing.T) { require.NoError(t, r.AddRawToken(ctx, boundaryAddr, at1a.Token)) require.NoError(t, r.AddRawToken(ctx, boundaryAddr, at1b.Token)) - got, err := r.cleanAndPickAuthTokens(ctx, u1) + got, err := rs.cleanAndPickAuthTokens(ctx, u1) assert.NoError(t, err) assert.ElementsMatch(t, maps.Values(got), []string{at1a.Token, at1b.Token}) }) @@ -133,7 +135,7 @@ func TestCleanAndPickTokens(t *testing.T) { require.NoError(t, r.AddRawToken(ctx, boundaryAddr, at1a.Token)) require.NoError(t, r.AddRawToken(ctx, boundaryAddr, at1b.Token)) - got, err := r.cleanAndPickAuthTokens(ctx, u1) + got, err := rs.cleanAndPickAuthTokens(ctx, u1) assert.NoError(t, err) assert.ElementsMatch(t, maps.Values(got), []string{at1a.Token, at1b.Token}) @@ -142,7 +144,7 @@ func TestCleanAndPickTokens(t *testing.T) { }) unauthorizedAuthTokens = []*authtokens.AuthToken{at1b} - got, err = r.cleanAndPickAuthTokens(ctx, u1) + got, err = rs.cleanAndPickAuthTokens(ctx, u1) assert.NoError(t, err) assert.ElementsMatch(t, maps.Values(got), []string{at1a.Token}) }) @@ -163,7 +165,7 @@ func TestCleanAndPickTokens(t *testing.T) { AuthTokenId: keyringAuthToken2.Id, })) - got, err := r.cleanAndPickAuthTokens(ctx, keyringOnlyUser) + got, err := rs.cleanAndPickAuthTokens(ctx, keyringOnlyUser) assert.NoError(t, err) assert.ElementsMatch(t, maps.Values(got), []string{keyringAuthToken1.Token, keyringAuthToken2.Token}) @@ -172,7 +174,7 @@ func TestCleanAndPickTokens(t *testing.T) { }) unauthorizedAuthTokens = []*authtokens.AuthToken{keyringAuthToken2} - got, err = r.cleanAndPickAuthTokens(ctx, keyringOnlyUser) + got, err = rs.cleanAndPickAuthTokens(ctx, keyringOnlyUser) assert.NoError(t, err) assert.ElementsMatch(t, maps.Values(got), []string{keyringAuthToken1.Token}) }) @@ -181,7 +183,7 @@ func TestCleanAndPickTokens(t *testing.T) { require.NoError(t, r.AddRawToken(ctx, boundaryAddr, at1a.Token)) require.NoError(t, r.AddRawToken(ctx, boundaryAddr, at1b.Token)) - got, err := r.cleanAndPickAuthTokens(ctx, u1) + got, err := rs.cleanAndPickAuthTokens(ctx, u1) assert.NoError(t, err) assert.ElementsMatch(t, maps.Values(got), []string{at1a.Token, at1b.Token}) @@ -190,7 +192,7 @@ func TestCleanAndPickTokens(t *testing.T) { }) randomErrorAuthTokens = []*authtokens.AuthToken{at1b} - got, err = r.cleanAndPickAuthTokens(ctx, u1) + got, err = rs.cleanAndPickAuthTokens(ctx, u1) assert.NoError(t, err) assert.ElementsMatch(t, maps.Values(got), []string{at1a.Token}) }) @@ -211,7 +213,7 @@ func TestCleanAndPickTokens(t *testing.T) { AuthTokenId: keyringAuthToken2.Id, })) - got, err := r.cleanAndPickAuthTokens(ctx, keyringOnlyUser) + got, err := rs.cleanAndPickAuthTokens(ctx, keyringOnlyUser) assert.NoError(t, err) assert.ElementsMatch(t, maps.Values(got), []string{keyringAuthToken1.Token, keyringAuthToken2.Token}) @@ -220,7 +222,7 @@ func TestCleanAndPickTokens(t *testing.T) { }) randomErrorAuthTokens = []*authtokens.AuthToken{keyringAuthToken2} - got, err = r.cleanAndPickAuthTokens(ctx, keyringOnlyUser) + got, err = rs.cleanAndPickAuthTokens(ctx, keyringOnlyUser) assert.NoError(t, err) assert.ElementsMatch(t, maps.Values(got), []string{keyringAuthToken1.Token}) }) @@ -241,7 +243,7 @@ func TestCleanAndPickTokens(t *testing.T) { AuthTokenId: keyringAuthToken2.Id, })) - got, err := r.cleanAndPickAuthTokens(ctx, keyringOnlyUser) + got, err := rs.cleanAndPickAuthTokens(ctx, keyringOnlyUser) assert.NoError(t, err) assert.ElementsMatch(t, maps.Values(got), []string{keyringAuthToken1.Token, keyringAuthToken2.Token}) @@ -253,7 +255,7 @@ func TestCleanAndPickTokens(t *testing.T) { delete(atMap, key1) delete(atMap, key2) - got, err = r.cleanAndPickAuthTokens(ctx, keyringOnlyUser) + got, err = rs.cleanAndPickAuthTokens(ctx, keyringOnlyUser) assert.NoError(t, err) assert.Empty(t, got) @@ -268,7 +270,7 @@ func TestCleanAndPickTokens(t *testing.T) { func TestRefresh(t *testing.T) { ctx := context.Background() - s, err := Open(ctx) + s, err := db.Open(ctx) require.NoError(t, err) boundaryAddr := "address" @@ -284,6 +286,8 @@ func TestRefresh(t *testing.T) { atMap := make(map[ringToken]*authtokens.AuthToken) r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(boundaryAuthTokens)) require.NoError(t, err) + rs, err := NewRefreshService(ctx, r) + require.NoError(t, err) atMap[ringToken{"k", "t"}] = at require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id})) @@ -294,7 +298,7 @@ func TestRefresh(t *testing.T) { target("2"), target("3"), } - assert.NoError(t, r.Refresh(ctx, + assert.NoError(t, rs.Refresh(ctx, WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](nil)), WithTargetRetrievalFunc(func(ctx context.Context, addr, token string) ([]*targets.Target, error) { require.Equal(t, boundaryAddr, addr) @@ -307,7 +311,7 @@ func TestRefresh(t *testing.T) { assert.ElementsMatch(t, retTargets, cachedTargets) t.Run("empty response clears it out", func(t *testing.T) { - assert.NoError(t, r.Refresh(ctx, + assert.NoError(t, rs.Refresh(ctx, WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](nil)), WithTargetRetrievalFunc(func(ctx context.Context, addr, token string) ([]*targets.Target, error) { require.Equal(t, boundaryAddr, addr) @@ -327,7 +331,7 @@ func TestRefresh(t *testing.T) { session("2"), session("3"), } - assert.NoError(t, r.Refresh(ctx, + assert.NoError(t, rs.Refresh(ctx, WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](nil)), WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(retSess)))) @@ -336,7 +340,7 @@ func TestRefresh(t *testing.T) { assert.ElementsMatch(t, retSess, cachedSessions) t.Run("empty response clears it out", func(t *testing.T) { - assert.NoError(t, r.Refresh(ctx, + assert.NoError(t, rs.Refresh(ctx, WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](nil)), WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](nil)))) @@ -347,8 +351,8 @@ func TestRefresh(t *testing.T) { }) t.Run("error propogates up", func(t *testing.T) { - innerErr := stdErrors.New("test error") - err := r.Refresh(ctx, + innerErr := errors.New("test error") + err := rs.Refresh(ctx, WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](nil)), WithTargetRetrievalFunc(func(ctx context.Context, addr, token string) ([]*targets.Target, error) { require.Equal(t, boundaryAddr, addr) @@ -356,7 +360,7 @@ func TestRefresh(t *testing.T) { return nil, innerErr })) assert.ErrorContains(t, err, innerErr.Error()) - err = r.Refresh(ctx, + err = rs.Refresh(ctx, WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](nil)), WithSessionRetrievalFunc(func(ctx context.Context, addr, token string) ([]*sessions.Session, error) { require.Equal(t, boundaryAddr, addr) @@ -380,7 +384,7 @@ func TestRefresh(t *testing.T) { require.NoError(t, err) assert.Len(t, us, 1) - r.Refresh(ctx, + rs.Refresh(ctx, WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](nil)), WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](nil))) diff --git a/internal/daemon/cache/repository.go b/internal/cache/repository.go similarity index 91% rename from internal/daemon/cache/repository.go rename to internal/cache/repository.go index cf3ee618fed..7fd528b7f02 100644 --- a/internal/daemon/cache/repository.go +++ b/internal/cache/repository.go @@ -34,11 +34,11 @@ type Repository struct { } // NewRepository returns a cache repository. The provided Store must be -func NewRepository(ctx context.Context, s *Store, idToAuthToken *sync.Map, keyringFn KeyringTokenLookupFn, atReadFn BoundaryTokenReaderFn, opt ...Option) (*Repository, error) { +func NewRepository(ctx context.Context, conn *db.DB, idToAuthToken *sync.Map, keyringFn KeyringTokenLookupFn, atReadFn BoundaryTokenReaderFn, opt ...Option) (*Repository, error) { const op = "cache.NewRepository" switch { - case util.IsNil(s): - return nil, errors.New(ctx, errors.InvalidParameter, op, "missing store") + case util.IsNil(conn): + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing store connection") case util.IsNil(idToAuthToken): return nil, errors.New(ctx, errors.InvalidParameter, op, "missing keyringless auth token map") case util.IsNil(keyringFn): @@ -47,7 +47,7 @@ func NewRepository(ctx context.Context, s *Store, idToAuthToken *sync.Map, keyri return nil, errors.New(ctx, errors.InvalidParameter, op, "missing auth token read function") } return &Repository{ - rw: db.New(s.conn), + rw: db.New(conn), tokenKeyringFn: keyringFn, tokenReadFromBoundaryFn: atReadFn, // This is passed in instead of being fully owned by the repo so multiple diff --git a/internal/daemon/cache/repository_sessions.go b/internal/cache/repository_sessions.go similarity index 100% rename from internal/daemon/cache/repository_sessions.go rename to internal/cache/repository_sessions.go diff --git a/internal/daemon/cache/repository_sessions_test.go b/internal/cache/repository_sessions_test.go similarity index 97% rename from internal/daemon/cache/repository_sessions_test.go rename to internal/cache/repository_sessions_test.go index f192c787917..ae0b4e67af6 100644 --- a/internal/daemon/cache/repository_sessions_test.go +++ b/internal/cache/repository_sessions_test.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/boundary/api/authtokens" "github.com/hashicorp/boundary/api/sessions" "github.com/hashicorp/boundary/api/targets" + cachedb "github.com/hashicorp/boundary/internal/cache/db" "github.com/hashicorp/boundary/internal/daemon/controller" "github.com/hashicorp/boundary/internal/daemon/worker" "github.com/hashicorp/boundary/internal/db" @@ -23,7 +24,7 @@ import ( func TestRepository_refreshSessions(t *testing.T) { ctx := context.Background() - s, err := Open(ctx) + s, err := cachedb.Open(ctx) require.NoError(t, err) addr := "address" @@ -118,7 +119,7 @@ func TestRepository_refreshSessions(t *testing.T) { WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(tc.sess))) if tc.errorContains == "" { assert.NoError(t, err) - rw := db.New(s.conn) + rw := db.New(s) var got []*Session require.NoError(t, rw.SearchWhere(ctx, &got, "true", nil)) assert.Len(t, got, tc.wantCount) @@ -131,7 +132,7 @@ func TestRepository_refreshSessions(t *testing.T) { func TestRepository_ListSessions(t *testing.T) { ctx := context.Background() - s, err := Open(ctx) + s, err := cachedb.Open(ctx) require.NoError(t, err) addr := "address" @@ -216,7 +217,7 @@ func TestRepository_ListSessions(t *testing.T) { func TestRepository_QuerySessions(t *testing.T) { ctx := context.Background() - s, err := Open(ctx) + s, err := cachedb.Open(ctx) require.NoError(t, err) addr := "address" diff --git a/internal/daemon/cache/repository_targets.go b/internal/cache/repository_targets.go similarity index 100% rename from internal/daemon/cache/repository_targets.go rename to internal/cache/repository_targets.go diff --git a/internal/daemon/cache/repository_targets_test.go b/internal/cache/repository_targets_test.go similarity index 98% rename from internal/daemon/cache/repository_targets_test.go rename to internal/cache/repository_targets_test.go index 5539553f758..8046a6f789a 100644 --- a/internal/daemon/cache/repository_targets_test.go +++ b/internal/cache/repository_targets_test.go @@ -10,6 +10,7 @@ import ( "github.com/hashicorp/boundary/api/authtokens" "github.com/hashicorp/boundary/api/targets" + cachedb "github.com/hashicorp/boundary/internal/cache/db" "github.com/hashicorp/boundary/internal/daemon/controller" "github.com/hashicorp/boundary/internal/db" "github.com/stretchr/testify/assert" @@ -19,7 +20,7 @@ import ( func TestRepository_refreshTargets(t *testing.T) { ctx := context.Background() - s, err := Open(ctx) + s, err := cachedb.Open(ctx) require.NoError(t, err) addr := "address" @@ -112,7 +113,7 @@ func TestRepository_refreshTargets(t *testing.T) { WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(tc.targets))) if tc.errorContains == "" { assert.NoError(t, err) - rw := db.New(s.conn) + rw := db.New(s) var got []*Target require.NoError(t, rw.SearchWhere(ctx, &got, "true", nil)) assert.Len(t, got, tc.wantCount) @@ -125,7 +126,7 @@ func TestRepository_refreshTargets(t *testing.T) { func TestRepository_ListTargets(t *testing.T) { ctx := context.Background() - s, err := Open(ctx) + s, err := cachedb.Open(ctx) require.NoError(t, err) addr := "address" @@ -206,7 +207,7 @@ func TestRepository_ListTargets(t *testing.T) { func TestRepository_QueryTargets(t *testing.T) { ctx := context.Background() - s, err := Open(ctx) + s, err := cachedb.Open(ctx) require.NoError(t, err) addr := "address" diff --git a/internal/daemon/cache/repository_test.go b/internal/cache/repository_test.go similarity index 96% rename from internal/daemon/cache/repository_test.go rename to internal/cache/repository_test.go index 5cc7403bd7f..b2bd4582e8c 100644 --- a/internal/daemon/cache/repository_test.go +++ b/internal/cache/repository_test.go @@ -11,6 +11,7 @@ import ( "testing" "github.com/hashicorp/boundary/api/authtokens" + cachedb "github.com/hashicorp/boundary/internal/cache/db" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -45,7 +46,7 @@ func sliceBasedAuthTokenBoundaryReader(s []*authtokens.AuthToken) BoundaryTokenR func TestRepository_SaveError(t *testing.T) { ctx := context.Background() - s, err := Open(ctx) + s, err := cachedb.Open(ctx) require.NoError(t, err) r, err := NewRepository(ctx, s, &sync.Map{}, diff --git a/internal/daemon/cache/repository_token.go b/internal/cache/repository_token.go similarity index 100% rename from internal/daemon/cache/repository_token.go rename to internal/cache/repository_token.go diff --git a/internal/daemon/cache/repository_token_test.go b/internal/cache/repository_token_test.go similarity index 97% rename from internal/daemon/cache/repository_token_test.go rename to internal/cache/repository_token_test.go index c98ea438e96..54a1aceb7d8 100644 --- a/internal/daemon/cache/repository_token_test.go +++ b/internal/cache/repository_token_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/hashicorp/boundary/api/authtokens" + cachedb "github.com/hashicorp/boundary/internal/cache/db" "github.com/hashicorp/boundary/internal/db" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -19,7 +20,7 @@ import ( func TestRepository_AddKeyringToken(t *testing.T) { ctx := context.Background() - s, err := Open(ctx) + s, err := cachedb.Open(ctx) require.NoError(t, err) addr := "address" @@ -150,7 +151,7 @@ func TestRepository_AddKeyringToken(t *testing.T) { func TestRepository_AddRawToken(t *testing.T) { ctx := context.Background() - s, err := Open(ctx) + s, err := cachedb.Open(ctx) require.NoError(t, err) at := &authtokens.AuthToken{ @@ -230,7 +231,7 @@ func TestRepository_AddRawToken(t *testing.T) { func TestRepository_AddToken_EvictsOverLimitUsers(t *testing.T) { ctx := context.Background() - s, err := Open(ctx) + s, err := cachedb.Open(ctx) require.NoError(t, err) boundaryAuthTokens := []*authtokens.AuthToken{ @@ -288,7 +289,7 @@ func TestRepository_AddToken_EvictsOverLimitUsers(t *testing.T) { func TestRepository_AddToken_EvictsOverLimit_Keyringless(t *testing.T) { ctx := context.Background() - s, err := Open(ctx) + s, err := cachedb.Open(ctx) require.NoError(t, err) boundaryAuthTokens := []*authtokens.AuthToken{ @@ -340,7 +341,7 @@ func TestRepository_AddToken_EvictsOverLimit_Keyringless(t *testing.T) { func TestRepository_CleanAuthTokens(t *testing.T) { ctx := context.Background() - s, err := Open(ctx) + s, err := cachedb.Open(ctx) require.NoError(t, err) at := &authtokens.AuthToken{ @@ -371,7 +372,7 @@ func TestRepository_CleanAuthTokens(t *testing.T) { func TestRepository_AddKeyringToken_AddingExistingUpdatesLastAccessedTime(t *testing.T) { ctx := context.Background() - s, err := Open(ctx) + s, err := cachedb.Open(ctx) require.NoError(t, err) addr := "address" @@ -422,7 +423,7 @@ func TestRepository_AddKeyringToken_AddingExistingUpdatesLastAccessedTime(t *tes func TestRepository_AddRawToken_AddingExistingUpdatesLastAccessedTime(t *testing.T) { ctx := context.Background() - s, err := Open(ctx) + s, err := cachedb.Open(ctx) require.NoError(t, err) addr := "address" @@ -460,7 +461,7 @@ func TestRepository_AddRawToken_AddingExistingUpdatesLastAccessedTime(t *testing func TestRepository_ListTokens(t *testing.T) { ctx := context.Background() - s, err := Open(ctx) + s, err := cachedb.Open(ctx) require.NoError(t, err) addr := "address" @@ -513,7 +514,7 @@ func TestRepository_ListTokens(t *testing.T) { func TestRepository_DeleteKeyringToken(t *testing.T) { ctx := context.Background() - s, err := Open(ctx) + s, err := cachedb.Open(ctx) require.NoError(t, err) addr := "address" @@ -555,7 +556,7 @@ func TestRepository_DeleteKeyringToken(t *testing.T) { func TestRepository_LookupToken(t *testing.T) { ctx := context.Background() - s, err := Open(ctx) + s, err := cachedb.Open(ctx) require.NoError(t, err) addr := "address" @@ -611,7 +612,7 @@ func TestRepository_LookupToken(t *testing.T) { func TestRepository_RemoveStaleTokens(t *testing.T) { ctx := context.Background() - s, err := Open(ctx) + s, err := cachedb.Open(ctx) require.NoError(t, err) u := &user{ @@ -707,9 +708,9 @@ func TestCleanExpiredOrOrphanedAuthTokens_Errors(t *testing.T) { assert.Error(t, err) assert.ErrorContains(t, err, "writer is nil") - s, err := Open(ctx) + s, err := cachedb.Open(ctx) require.NoError(t, err) - rw := db.New(s.conn) + rw := db.New(s) err = cleanExpiredOrOrphanedAuthTokens(ctx, rw, &sync.Map{}) assert.Error(t, err) @@ -726,9 +727,9 @@ func TestCleanExpiredOrOrphanedAuthTokens_Errors(t *testing.T) { func TestUpsertUserAndAuthToken(t *testing.T) { ctx := context.Background() - s, err := Open(ctx) + s, err := cachedb.Open(ctx) require.NoError(t, err) - rw := db.New(s.conn) + rw := db.New(s) defaultAt := &authtokens.AuthToken{ Id: "at_123", diff --git a/internal/daemon/cache/search.go b/internal/cache/search.go similarity index 100% rename from internal/daemon/cache/search.go rename to internal/cache/search.go diff --git a/internal/daemon/cache/search_test.go b/internal/cache/search_test.go similarity index 97% rename from internal/daemon/cache/search_test.go rename to internal/cache/search_test.go index 7465b244c6b..bce918c06b4 100644 --- a/internal/daemon/cache/search_test.go +++ b/internal/cache/search_test.go @@ -10,6 +10,7 @@ import ( "github.com/hashicorp/boundary/api/sessions" "github.com/hashicorp/boundary/api/targets" + cachedb "github.com/hashicorp/boundary/internal/cache/db" "github.com/hashicorp/boundary/internal/db" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -26,7 +27,7 @@ func TestNewSearchService(t *testing.T) { }) t.Run("success", func(t *testing.T) { - s, err := Open(ctx) + s, err := cachedb.Open(ctx) require.NoError(t, err) r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(nil), @@ -41,7 +42,7 @@ func TestNewSearchService(t *testing.T) { func TestSearch_Errors(t *testing.T) { ctx := context.Background() - s, err := Open(ctx) + s, err := cachedb.Open(ctx) require.NoError(t, err) r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(nil), @@ -104,7 +105,7 @@ func TestSearch_Errors(t *testing.T) { func TestSearch(t *testing.T) { ctx := context.Background() - s, err := Open(ctx) + s, err := cachedb.Open(ctx) require.NoError(t, err) at := &AuthToken{ @@ -113,7 +114,7 @@ func TestSearch(t *testing.T) { } { u := &user{Id: at.UserId, Address: "address"} - rw := db.New(s.conn) + rw := db.New(s) require.NoError(t, rw.Create(ctx, u)) require.NoError(t, rw.Create(ctx, at)) diff --git a/internal/daemon/cache/store_test.go b/internal/cache/store_test.go similarity index 96% rename from internal/daemon/cache/store_test.go rename to internal/cache/store_test.go index e5a1b12caed..6e6987d3b30 100644 --- a/internal/daemon/cache/store_test.go +++ b/internal/cache/store_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + cachedb "github.com/hashicorp/boundary/internal/cache/db" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/errors" "github.com/stretchr/testify/assert" @@ -16,9 +17,9 @@ import ( func TestUser(t *testing.T) { ctx := context.Background() - s, err := Open(ctx) + conn, err := cachedb.Open(ctx) require.NoError(t, err) - rw := db.New(s.conn) + rw := db.New(conn) t.Run("missing address", func(t *testing.T) { u := &user{ @@ -61,9 +62,9 @@ func TestUser(t *testing.T) { func TestUser_NoMoreTokens(t *testing.T) { ctx := context.Background() - s, err := Open(ctx) + conn, err := cachedb.Open(ctx) require.NoError(t, err) - rw := db.New(s.conn) + rw := db.New(conn) u := &user{ Id: "userId", @@ -96,9 +97,9 @@ func TestUser_NoMoreTokens(t *testing.T) { func TestAuthToken_NoMoreKeyringTokens(t *testing.T) { ctx := context.Background() - s, err := Open(ctx) + s, err := cachedb.Open(ctx) require.NoError(t, err) - rw := db.New(s.conn) + rw := db.New(s) u := &user{ Id: "userId", @@ -135,10 +136,10 @@ func TestAuthToken_NoMoreKeyringTokens(t *testing.T) { func TestAuthToken(t *testing.T) { ctx := context.Background() - s, err := Open(ctx) + s, err := cachedb.Open(ctx) require.NoError(t, err) - rw := db.New(s.conn) + rw := db.New(s) u := &user{ Id: "userId", @@ -212,10 +213,9 @@ func TestAuthToken(t *testing.T) { func TestKeyringToken(t *testing.T) { ctx := context.Background() - s, err := Open(ctx) + s, err := cachedb.Open(ctx) require.NoError(t, err) - - rw := db.New(s.conn) + rw := db.New(s) u := &user{ Id: "userId", @@ -315,10 +315,10 @@ func TestKeyringToken(t *testing.T) { func TestTarget(t *testing.T) { ctx := context.Background() - s, err := Open(ctx) + s, err := cachedb.Open(ctx) require.NoError(t, err) - rw := db.New(s.conn) + rw := db.New(s) addr := "boundary" userId := "u_12345" @@ -417,10 +417,10 @@ func TestTarget(t *testing.T) { func TestSession(t *testing.T) { ctx := context.Background() - s, err := Open(ctx) + s, err := cachedb.Open(ctx) require.NoError(t, err) - rw := db.New(s.conn) + rw := db.New(s) addr := "boundary" userId := "u_12345" diff --git a/internal/cmd/commands/daemon/addtoken.go b/internal/cmd/commands/daemon/addtoken.go index 167fab95064..ac3dd8334c4 100644 --- a/internal/cmd/commands/daemon/addtoken.go +++ b/internal/cmd/commands/daemon/addtoken.go @@ -10,6 +10,7 @@ import ( "github.com/hashicorp/boundary/api" "github.com/hashicorp/boundary/internal/cmd/base" + "github.com/hashicorp/boundary/internal/daemon/cache" "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/version" "github.com/mitchellh/cli" @@ -89,7 +90,7 @@ func (c *AddTokenCommand) Add(ctx context.Context) (*api.Error, error) { return nil, err } - pa := upsertTokenRequest{ + pa := cache.UpsertTokenRequest{ BoundaryAddr: client.Addr(), } switch keyringType { @@ -107,7 +108,7 @@ func (c *AddTokenCommand) Add(ctx context.Context) (*api.Error, error) { if at == nil { return nil, errors.New(ctx, errors.Conflict, op, "no auth token available to send to daemon") } - pa.Keyring = &keyringToken{ + pa.Keyring = &cache.KeyringToken{ KeyringType: keyringType, TokenName: tokenName, } @@ -122,13 +123,13 @@ func (c *AddTokenCommand) Add(ctx context.Context) (*api.Error, error) { return addToken(ctx, dotPath, &pa) } -func addToken(ctx context.Context, daemonPath string, p *upsertTokenRequest) (*api.Error, error) { +func addToken(ctx context.Context, daemonPath string, p *cache.UpsertTokenRequest) (*api.Error, error) { const op = "daemon.addToken" client, err := api.NewClient(nil) if err != nil { return nil, errors.Wrap(ctx, err, op) } - addr := SocketAddress(daemonPath) + addr := cache.SocketAddress(daemonPath) _, err = os.Stat(strings.TrimPrefix(addr, "unix://")) if strings.HasPrefix(addr, "unix://") && err != nil { return nil, errors.Wrap(ctx, err, op) @@ -144,7 +145,7 @@ func addToken(ctx context.Context, daemonPath string, p *upsertTokenRequest) (*a if err != nil { return nil, err } - req.Header.Add(VersionHeaderKey, version.Get().VersionNumber()) + req.Header.Add(cache.VersionHeaderKey, version.Get().VersionNumber()) resp, err := client.Do(req) if err != nil { return nil, err diff --git a/internal/cmd/commands/daemon/command_wrapper.go b/internal/cmd/commands/daemon/command_wrapper.go index ec0b8d8759c..57a87209661 100644 --- a/internal/cmd/commands/daemon/command_wrapper.go +++ b/internal/cmd/commands/daemon/command_wrapper.go @@ -15,25 +15,19 @@ import ( "github.com/mitchellh/cli" ) -// wrappableCommand defines the interface for the commands that can be wrapped. -type wrappableCommand interface { - cli.Command - Commander -} - // CommandWrapper starts the boundary daemon after the command was Run and attempts // to send the current persona to any running daemon. type CommandWrapper struct { - wrappableCommand + cli.Command ui cli.Ui } // Wrap returns a cli.CommandFactory that returns a command wrapped in the CommandWrapper. -func Wrap(ui cli.Ui, wrapped wrappableCommand) cli.CommandFactory { +func Wrap(ui cli.Ui, wrapped cli.Command) cli.CommandFactory { return func() (cli.Command, error) { return &CommandWrapper{ - wrappableCommand: wrapped, - ui: ui, + Command: wrapped, + ui: ui, }, nil } } @@ -41,7 +35,7 @@ func Wrap(ui cli.Ui, wrapped wrappableCommand) cli.CommandFactory { // Run runs the wrapped command and then attempts to start the boundary daemon and send // the current persona func (w *CommandWrapper) Run(args []string) int { - r := w.wrappableCommand.Run(args) + r := w.Command.Run(args) ctx := context.Background() if w.startDaemon(ctx) { diff --git a/internal/cmd/commands/daemon/start.go b/internal/cmd/commands/daemon/start.go index 2f6848f6220..65205b38b87 100644 --- a/internal/cmd/commands/daemon/start.go +++ b/internal/cmd/commands/daemon/start.go @@ -15,14 +15,13 @@ import ( "sync" "github.com/hashicorp/boundary/internal/cmd/base" + "github.com/hashicorp/boundary/internal/daemon/cache" "github.com/hashicorp/boundary/internal/errors" "github.com/mitchellh/cli" "github.com/mitchellh/go-homedir" "github.com/posener/complete" ) -const DefaultRefreshIntervalSeconds = 5 * 60 - const ( dotDirname = ".boundary" pidFileName = "cache.pid" @@ -40,7 +39,7 @@ var ( type server interface { setupLogging(context.Context, io.Writer) error - serve(context.Context, Commander, net.Listener) error + serve(context.Context, cache.Commander, net.Listener) error shutdown() error } @@ -101,7 +100,7 @@ func (c *StartCommand) Flags() *base.FlagSets { Target: &c.flagRefreshIntervalSeconds, Usage: `If set, specifies the number of seconds between cache refreshes. Default: 5 minutes`, Aliases: []string{"r"}, - Default: DefaultRefreshIntervalSeconds, + Default: cache.DefaultRefreshIntervalSeconds, }) f.BoolVar(&base.BoolVar{ Name: "store-debug", @@ -175,39 +174,34 @@ func (c *StartCommand) Run(args []string) int { } writers = append(writers, logFile) - cfg := &serverConfig{ - contextCancel: cancel, - refreshIntervalSeconds: c.flagRefreshIntervalSeconds, - flagDatabaseUrl: c.flagDatabaseUrl, - flagStoreDebug: c.flagStoreDebug, - flagLogLevel: c.flagLogLevel, - flagLogFormat: c.flagLogFormat, - logWriter: io.MultiWriter(writers...), + cfg := &cache.Config{ + ContextCancel: cancel, + RefreshIntervalSeconds: c.flagRefreshIntervalSeconds, + DatabaseUrl: c.flagDatabaseUrl, + StoreDebug: c.flagStoreDebug, + LogLevel: c.flagLogLevel, + LogFormat: c.flagLogFormat, + LogWriter: io.MultiWriter(writers...), } - srv, err := newServer(ctx, cfg) + srv, err := cache.New(ctx, cfg) if err != nil { c.UI.Error(err.Error()) return base.CommandUserError } - l, err := listener(ctx, dotDir) - if err != nil { - c.PrintCliError(err) - return base.CommandCliError - } var srvErr error var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() - srvErr = srv.serve(ctx, c, l) + srvErr = srv.Serve(ctx, c) }() // This is a blocking call. We rely on the c.ShutdownCh to cancel this // context when sigterm or sigint is received. <-ctx.Done() - if err := srv.shutdown(ctx); err != nil { + if err := srv.Shutdown(ctx); err != nil { c.PrintCliError(err) return base.CommandCliError } diff --git a/internal/cmd/commands/daemon/token_test.go b/internal/cmd/commands/daemon/token_test.go index 1e40ea90c88..5e9ae57883b 100644 --- a/internal/cmd/commands/daemon/token_test.go +++ b/internal/cmd/commands/daemon/token_test.go @@ -5,14 +5,16 @@ package daemon import ( "context" - stdErrors "errors" + "errors" "net/http" "sync" "testing" "github.com/hashicorp/boundary/api/authtokens" + "github.com/hashicorp/boundary/internal/cache" + cachedb "github.com/hashicorp/boundary/internal/cache/db" "github.com/hashicorp/boundary/internal/cmd/base" - "github.com/hashicorp/boundary/internal/daemon/cache" + cachedaemon "github.com/hashicorp/boundary/internal/daemon/cache" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -41,7 +43,7 @@ func sliceBasedAuthTokenBoundaryReader(s []*authtokens.AuthToken) cache.Boundary return v, nil } } - return nil, stdErrors.New("not found") + return nil, errors.New("not found") } } @@ -54,8 +56,11 @@ func (r *testRefresher) refresh() { } func TestKeyringToken(t *testing.T) { + // TODO: for this to compile this needs to be refactored to look closer + // to internal/cmd/commands/search/search_test.go instead of referencing the + // handler logic directly ctx := context.Background() - s, _, err := openStore(ctx, "", false) + s, err := cachedb.Open(ctx) require.NoError(t, err) at := &authtokens.AuthToken{ @@ -73,14 +78,14 @@ func TestKeyringToken(t *testing.T) { require.NoError(t, err) tr := &testRefresher{} - ph, err := newTokenHandlerFunc(ctx, r, tr) - require.NoError(t, err) + // ph, err := newTokenHandlerFunc(ctx, r, tr) + // require.NoError(t, err) mux := http.NewServeMux() - mux.HandleFunc("/v1/tokens", ph) + // mux.HandleFunc("/v1/tokens", ph) tmpdir := t.TempDir() - l, err := listener(ctx, tmpdir) + l, err := cachedaemon.Listener(ctx, tmpdir) require.NoError(t, err) srv := &http.Server{ Handler: mux, @@ -93,8 +98,8 @@ func TestKeyringToken(t *testing.T) { }() t.Run("missing keyring", func(t *testing.T) { - pa := &upsertTokenRequest{ - Keyring: &keyringToken{ + pa := &cachedaemon.UpsertTokenRequest{ + Keyring: &cachedaemon.KeyringToken{ KeyringType: "", TokenName: tokenName, }, @@ -109,8 +114,8 @@ func TestKeyringToken(t *testing.T) { }) t.Run("none keyring", func(t *testing.T) { - pa := &upsertTokenRequest{ - Keyring: &keyringToken{ + pa := &cachedaemon.UpsertTokenRequest{ + Keyring: &cachedaemon.KeyringToken{ KeyringType: base.NoneKeyring, TokenName: tokenName, }, @@ -125,8 +130,8 @@ func TestKeyringToken(t *testing.T) { }) t.Run("missing token name", func(t *testing.T) { - pa := &upsertTokenRequest{ - Keyring: &keyringToken{ + pa := &cachedaemon.UpsertTokenRequest{ + Keyring: &cachedaemon.KeyringToken{ KeyringType: keyring, TokenName: "", }, @@ -141,8 +146,8 @@ func TestKeyringToken(t *testing.T) { }) t.Run("missing boundary address", func(t *testing.T) { - pa := &upsertTokenRequest{ - Keyring: &keyringToken{ + pa := &cachedaemon.UpsertTokenRequest{ + Keyring: &cachedaemon.KeyringToken{ KeyringType: keyring, TokenName: tokenName, }, @@ -157,8 +162,8 @@ func TestKeyringToken(t *testing.T) { }) t.Run("missing auth token id", func(t *testing.T) { - pa := &upsertTokenRequest{ - Keyring: &keyringToken{ + pa := &cachedaemon.UpsertTokenRequest{ + Keyring: &cachedaemon.KeyringToken{ KeyringType: keyring, TokenName: tokenName, }, @@ -173,8 +178,8 @@ func TestKeyringToken(t *testing.T) { }) t.Run("mismatched auth token id", func(t *testing.T) { - pa := &upsertTokenRequest{ - Keyring: &keyringToken{ + pa := &cachedaemon.UpsertTokenRequest{ + Keyring: &cachedaemon.KeyringToken{ KeyringType: keyring, TokenName: tokenName, }, @@ -189,8 +194,8 @@ func TestKeyringToken(t *testing.T) { }) t.Run("success", func(t *testing.T) { - pa := &upsertTokenRequest{ - Keyring: &keyringToken{ + pa := &cachedaemon.UpsertTokenRequest{ + Keyring: &cachedaemon.KeyringToken{ KeyringType: keyring, TokenName: tokenName, }, @@ -213,7 +218,7 @@ func TestKeyringToken(t *testing.T) { func TestKeyringlessToken(t *testing.T) { ctx := context.Background() - s, _, err := openStore(ctx, "", false) + s, err := cachedb.Open(ctx) require.NoError(t, err) at := &authtokens.AuthToken{ @@ -226,15 +231,18 @@ func TestKeyringlessToken(t *testing.T) { r, err := cache.NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(boundaryAuthTokens)) require.NoError(t, err) + // TODO: for this to compile this needs to be refactored to look closer + // to internal/cmd/commands/search/search_test.go instead of referencing the + // handler logic directly tr := &testRefresher{} - ph, err := newTokenHandlerFunc(ctx, r, tr) - require.NoError(t, err) + //ph, err := newTokenHandlerFunc(ctx, r, tr) + //require.NoError(t, err) mux := http.NewServeMux() - mux.HandleFunc("/v1/tokens", ph) + //mux.HandleFunc("/v1/tokens", ph) tmpdir := t.TempDir() - l, err := listener(ctx, tmpdir) + l, err := cachedaemon.Listener(ctx, tmpdir) require.NoError(t, err) srv := &http.Server{ Handler: mux, @@ -247,7 +255,7 @@ func TestKeyringlessToken(t *testing.T) { }() t.Run("missing boundary address", func(t *testing.T) { - pa := &upsertTokenRequest{ + pa := &cachedaemon.UpsertTokenRequest{ BoundaryAddr: "", AuthTokenId: at.Id, AuthToken: at.Token, @@ -260,7 +268,7 @@ func TestKeyringlessToken(t *testing.T) { }) t.Run("missing auth token id", func(t *testing.T) { - pa := &upsertTokenRequest{ + pa := &cachedaemon.UpsertTokenRequest{ BoundaryAddr: "http://127.0.0.1", AuthTokenId: "", AuthToken: at.Token, @@ -273,7 +281,7 @@ func TestKeyringlessToken(t *testing.T) { }) t.Run("mismatched auth token id", func(t *testing.T) { - pa := &upsertTokenRequest{ + pa := &cachedaemon.UpsertTokenRequest{ BoundaryAddr: "http://127.0.0.1", AuthTokenId: "at_doesntmatch", AuthToken: at.Token, @@ -286,7 +294,7 @@ func TestKeyringlessToken(t *testing.T) { }) t.Run("success", func(t *testing.T) { - pa := &upsertTokenRequest{ + pa := &cachedaemon.UpsertTokenRequest{ BoundaryAddr: "http://127.0.0.1", AuthTokenId: at.Id, AuthToken: at.Token, diff --git a/internal/cmd/commands/search/search.go b/internal/cmd/commands/search/search.go index 7c242463cad..d18483299a8 100644 --- a/internal/cmd/commands/search/search.go +++ b/internal/cmd/commands/search/search.go @@ -17,6 +17,7 @@ import ( "github.com/hashicorp/boundary/api/targets" "github.com/hashicorp/boundary/internal/cmd/base" "github.com/hashicorp/boundary/internal/cmd/commands/daemon" + "github.com/hashicorp/boundary/internal/daemon/cache" "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/version" "github.com/mitchellh/cli" @@ -114,7 +115,7 @@ func (c *SearchCommand) Run(args []string) int { c.PrintCliError(err) return base.CommandCliError } - res := &daemon.SearchResult{} + res := &cache.SearchResult{} apiErr, err := resp.Decode(res) if err != nil { c.PrintCliError(err) @@ -168,7 +169,7 @@ func search(ctx context.Context, daemonPath string, fb filterBy) (*api.Response, if err != nil { return nil, errors.Wrap(ctx, err, op) } - addr := daemon.SocketAddress(daemonPath) + addr := cache.SocketAddress(daemonPath) if err != nil { return nil, errors.Wrap(ctx, err, op) } @@ -183,7 +184,7 @@ func search(ctx context.Context, daemonPath string, fb filterBy) (*api.Response, if err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithMsg("new client request error")) } - req.Header.Add(daemon.VersionHeaderKey, version.Get().VersionNumber()) + req.Header.Add(cache.VersionHeaderKey, version.Get().VersionNumber()) q := url.Values{} q.Add("auth_token_id", fb.authTokenId) q.Add("resource", fb.resource) diff --git a/internal/cmd/commands/search/search_test.go b/internal/cmd/commands/search/search_test.go index 497388fdb89..9bba9e289d4 100644 --- a/internal/cmd/commands/search/search_test.go +++ b/internal/cmd/commands/search/search_test.go @@ -15,7 +15,7 @@ import ( "github.com/hashicorp/boundary/api/sessions" "github.com/hashicorp/boundary/api/targets" "github.com/hashicorp/boundary/internal/cmd/base" - "github.com/hashicorp/boundary/internal/cmd/commands/daemon" + "github.com/hashicorp/boundary/internal/daemon/cache" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -63,12 +63,12 @@ func TestSearch(t *testing.T) { return nil, errors.New("test not found error") } - srv := daemon.NewTestServer(t, cmd) + srv := cache.NewTestServer(t, cmd) var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() - srv.Serve(t, daemon.WithBoundaryTokenReaderFunc(ctx, boundaryTokenReaderFn)) + srv.Serve(t, cache.WithBoundaryTokenReaderFunc(ctx, boundaryTokenReaderFn)) }() // Give the store some time to get initialized time.Sleep(100 * time.Millisecond) @@ -119,7 +119,7 @@ func TestSearch(t *testing.T) { t.Run(tc.name, func(t *testing.T) { resp, err := search(ctx, srv.BaseSocketDir(), tc.fb) require.NoError(t, err) - r := daemon.SearchResult{} + r := cache.SearchResult{} apiErr, err := resp.Decode(&r) assert.NoError(t, err) assert.NotNil(t, apiErr) @@ -133,12 +133,12 @@ func TestSearch(t *testing.T) { resource: "targets", }) require.NoError(t, err) - r := daemon.SearchResult{} + r := cache.SearchResult{} apiErr, err := resp.Decode(&r) assert.NoError(t, err) assert.Nil(t, apiErr) assert.NotNil(t, r) - assert.EqualValues(t, r, daemon.SearchResult{}) + assert.EqualValues(t, r, cache.SearchResult{}) }) t.Run("empty response from query", func(t *testing.T) { @@ -148,12 +148,12 @@ func TestSearch(t *testing.T) { resource: "targets", }) require.NoError(t, err) - r := daemon.SearchResult{} + r := cache.SearchResult{} apiErr, err := resp.Decode(&r) assert.NoError(t, err) assert.Nil(t, apiErr) assert.NotNil(t, r) - assert.EqualValues(t, r, daemon.SearchResult{}) + assert.EqualValues(t, r, cache.SearchResult{}) }) srv.AddResources(t, cmd.at, []*targets.Target{ @@ -170,7 +170,7 @@ func TestSearch(t *testing.T) { resource: "targets", }) require.NoError(t, err) - r := daemon.SearchResult{} + r := cache.SearchResult{} apiErr, err := resp.Decode(&r) assert.NoError(t, err) assert.Nil(t, apiErr) @@ -184,7 +184,7 @@ func TestSearch(t *testing.T) { resource: "targets", }) require.NoError(t, err) - r := daemon.SearchResult{} + r := cache.SearchResult{} apiErr, err := resp.Decode(&r) assert.NoError(t, err) assert.Nil(t, apiErr) @@ -198,7 +198,7 @@ func TestSearch(t *testing.T) { resource: "targets", }) require.NoError(t, err) - r := daemon.SearchResult{} + r := cache.SearchResult{} apiErr, err := resp.Decode(&r) assert.NoError(t, err) assert.Nil(t, apiErr) @@ -213,7 +213,7 @@ func TestSearch(t *testing.T) { resource: "sessions", }) require.NoError(t, err) - r := daemon.SearchResult{} + r := cache.SearchResult{} apiErr, err := resp.Decode(&r) assert.NoError(t, err) assert.Nil(t, apiErr) @@ -228,7 +228,7 @@ func TestSearch(t *testing.T) { resource: "sessions", }) require.NoError(t, err) - r := daemon.SearchResult{} + r := cache.SearchResult{} apiErr, err := resp.Decode(&r) assert.NoError(t, err) assert.Nil(t, apiErr) @@ -242,7 +242,7 @@ func TestSearch(t *testing.T) { resource: "sessions", }) require.NoError(t, err) - r := daemon.SearchResult{} + r := cache.SearchResult{} apiErr, err := resp.Decode(&r) assert.NoError(t, err) assert.Nil(t, apiErr) diff --git a/internal/cmd/commands/daemon/handlers.go b/internal/daemon/cache/handlers.go similarity index 97% rename from internal/cmd/commands/daemon/handlers.go rename to internal/daemon/cache/handlers.go index 3562ecfe5e0..fae6d48b4bd 100644 --- a/internal/cmd/commands/daemon/handlers.go +++ b/internal/daemon/cache/handlers.go @@ -1,7 +1,7 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 -package daemon +package cache import ( "encoding/json" diff --git a/internal/cmd/commands/daemon/listener.go b/internal/daemon/cache/listener.go similarity index 92% rename from internal/cmd/commands/daemon/listener.go rename to internal/daemon/cache/listener.go index 12b97b2aae4..b5c48c569bc 100644 --- a/internal/cmd/commands/daemon/listener.go +++ b/internal/daemon/cache/listener.go @@ -1,7 +1,7 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 -package daemon +package cache import ( "context" @@ -26,8 +26,8 @@ const ( socketPerms = 0o600 ) -// listener provides a listener on the daemon unix socket. -func listener(ctx context.Context, path string) (net.Listener, error) { +// Listener provides a listener on the daemon unix socket. +func Listener(ctx context.Context, path string) (net.Listener, error) { const op = "daemon.listener" socketName := filepath.Join(path, sockAddr) if err := os.Remove(socketName); err != nil { diff --git a/internal/cmd/commands/daemon/listener_nonwindows_test.go b/internal/daemon/cache/listener_nonwindows_test.go similarity index 95% rename from internal/cmd/commands/daemon/listener_nonwindows_test.go rename to internal/daemon/cache/listener_nonwindows_test.go index 20fa117a3f3..1eddbbd99a5 100644 --- a/internal/cmd/commands/daemon/listener_nonwindows_test.go +++ b/internal/daemon/cache/listener_nonwindows_test.go @@ -4,7 +4,7 @@ //go:build !windows // +build !windows -package daemon +package cache import ( "context" @@ -23,7 +23,7 @@ func TestListenerSocketPermissions(t *testing.T) { path, err := os.MkdirTemp("", "*") require.NoError(t, err) - l, err := listener(ctx, path) + l, err := Listener(ctx, path) require.NoError(t, err) socketFile := l.Addr().String() fi, err := os.Stat(socketFile) diff --git a/internal/cmd/commands/daemon/listener_test.go b/internal/daemon/cache/listener_test.go similarity index 95% rename from internal/cmd/commands/daemon/listener_test.go rename to internal/daemon/cache/listener_test.go index bf24618092d..f1353d89d13 100644 --- a/internal/cmd/commands/daemon/listener_test.go +++ b/internal/daemon/cache/listener_test.go @@ -1,7 +1,7 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 -package daemon +package cache import ( "context" @@ -22,7 +22,7 @@ import ( func TestListenComms(t *testing.T) { ctx := context.Background() path := t.TempDir() - socketListener, err := listener(ctx, path) + socketListener, err := Listener(ctx, path) require.NoError(t, err) mux := http.NewServeMux() diff --git a/internal/daemon/cache/options.go b/internal/daemon/cache/options.go index 7b37febd8a0..0835ba53e22 100644 --- a/internal/daemon/cache/options.go +++ b/internal/daemon/cache/options.go @@ -4,25 +4,21 @@ package cache import ( - "github.com/hashicorp/go-dbw" + "context" + + "github.com/hashicorp/boundary/internal/cache" ) type options struct { - withDebug bool - withUrl string - withUpdateLastAccessedTime bool - withDbType dbw.DbType - withTargetRetrievalFunc TargetRetrievalFunc - withSessionRetrievalFunc SessionRetrievalFunc + withDebug bool + withBoundaryTokenReaderFunc cache.BoundaryTokenReaderFn } // Option - how options are passed as args type Option func(*options) error func getDefaultOptions() options { - return options{ - withDbType: dbw.Sqlite, - } + return options{} } func getOpts(opt ...Option) (options, error) { @@ -36,42 +32,18 @@ func getOpts(opt ...Option) (options, error) { return opts, nil } -// WithUrls provides optional url -func WithUrl(url string) Option { - return func(o *options) error { - o.withUrl = url - return nil - } -} - // WithDebug provides an optional debug flag. -func WithDebug(debug bool) Option { +func WithDebug(_ context.Context, debug bool) Option { return func(o *options) error { o.withDebug = debug return nil } } -// WithUpdateLastAccessedTime provides an option for updating the last access time -func WithUpdateLastAccessedTime(b bool) Option { - return func(o *options) error { - o.withUpdateLastAccessedTime = b - return nil - } -} - -// WithTargetRetrievalFunc provides an option for specifying a targetRetrievalFunc -func WithTargetRetrievalFunc(fn TargetRetrievalFunc) Option { - return func(o *options) error { - o.withTargetRetrievalFunc = fn - return nil - } -} - -// WithSessionRetrievalFunc provides an option for specifying a sessionRetrievalFunc -func WithSessionRetrievalFunc(fn SessionRetrievalFunc) Option { +// WithBoundaryTokenReaderFunc provides an option for specifying a BoundaryTokenReaderFn +func WithBoundaryTokenReaderFunc(_ context.Context, fn cache.BoundaryTokenReaderFn) Option { return func(o *options) error { - o.withSessionRetrievalFunc = fn + o.withBoundaryTokenReaderFunc = fn return nil } } diff --git a/internal/daemon/cache/options_test.go b/internal/daemon/cache/options_test.go index d28769611f6..b688788f4cf 100644 --- a/internal/daemon/cache/options_test.go +++ b/internal/daemon/cache/options_test.go @@ -7,66 +7,38 @@ import ( "context" "testing" - "github.com/hashicorp/boundary/api/sessions" - "github.com/hashicorp/boundary/api/targets" - "github.com/hashicorp/go-dbw" + "github.com/hashicorp/boundary/api/authtokens" + "github.com/hashicorp/boundary/internal/cache" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func Test_GetOpts(t *testing.T) { t.Parallel() + ctx := context.Background() t.Run("default", func(t *testing.T) { opts, err := getOpts() require.NoError(t, err) - testOpts := options{ - withDbType: dbw.Sqlite, - } - assert.Equal(t, opts, testOpts) - }) - t.Run("WithUrl", func(t *testing.T) { - url := "something" - opts, err := getOpts(WithUrl(url)) - require.NoError(t, err) - testOpts := getDefaultOptions() - testOpts.withUrl = url + testOpts := options{} assert.Equal(t, opts, testOpts) }) t.Run("WithDebug", func(t *testing.T) { - opts, err := getOpts(WithDebug(true)) + opts, err := getOpts(WithDebug(ctx, true)) require.NoError(t, err) testOpts := getDefaultOptions() testOpts.withDebug = true assert.Equal(t, opts, testOpts) }) - t.Run("WithUpdateLastAccessedTime", func(t *testing.T) { - opts, err := getOpts(WithUpdateLastAccessedTime(true)) - require.NoError(t, err) - testOpts := getDefaultOptions() - testOpts.withUpdateLastAccessedTime = true - assert.Equal(t, opts, testOpts) - }) - t.Run("WithTargetRetrievalFunc", func(t *testing.T) { - var f TargetRetrievalFunc = func(ctx context.Context, keyringstring, tokenName string) ([]*targets.Target, error) { return nil, nil } - opts, err := getOpts(WithTargetRetrievalFunc(f)) - require.NoError(t, err) - - assert.NotNil(t, opts.withTargetRetrievalFunc) - opts.withTargetRetrievalFunc = nil - - testOpts := getDefaultOptions() - assert.Equal(t, opts, testOpts) - }) - t.Run("WithSessionRetrievalFunc", func(t *testing.T) { - var f SessionRetrievalFunc = func(ctx context.Context, keyringstring, tokenName string) ([]*sessions.Session, error) { + t.Run("WithBoundaryTokenReaderFunc", func(t *testing.T) { + var f cache.BoundaryTokenReaderFn = func(ctx context.Context, addr, token string) (*authtokens.AuthToken, error) { return nil, nil } - opts, err := getOpts(WithSessionRetrievalFunc(f)) + opts, err := getOpts(WithBoundaryTokenReaderFunc(ctx, f)) require.NoError(t, err) - assert.NotNil(t, opts.withSessionRetrievalFunc) - opts.withSessionRetrievalFunc = nil + assert.NotNil(t, opts.withBoundaryTokenReaderFunc) + opts.withBoundaryTokenReaderFunc = nil testOpts := getDefaultOptions() assert.Equal(t, opts, testOpts) diff --git a/internal/daemon/cache/pidfile_nonwindows.go b/internal/daemon/cache/pidfile_nonwindows.go new file mode 100644 index 00000000000..88e5182ab5c --- /dev/null +++ b/internal/daemon/cache/pidfile_nonwindows.go @@ -0,0 +1,73 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +//go:build !windows +// +build !windows + +package cache + +import ( + "context" + stderrors "errors" + "os" + + "github.com/hashicorp/boundary/internal/errors" + "github.com/sevlyar/go-daemon" +) + +func writePidFile(ctx context.Context, pidFile string) (pidCleanup, error) { + const op = "daemon.writePidFile" + + // Determine if we should clean up the file after we are done or if + // it should stick around in the case of lock aquision error since this + // file didn't create it. + var pidExists bool + _, err := os.Stat(pidFile) + if err != nil && !errors.Is(err, os.ErrNotExist) { + return noopPidCleanup, errors.Wrap(ctx, err, op) + } + if err == nil { + pidExists = true + } + + f, err := os.OpenFile(pidFile, os.O_CREATE|os.O_WRONLY, 0o600) + if err != nil { + return noopPidCleanup, errors.Wrap(ctx, err, op, errors.WithMsg("opening file")) + } + closeAndDeleteFileFn := func() error { + err := f.Close() + if !pidExists { + err = stderrors.Join(err, os.Remove(pidFile)) + } + return err + } + l := daemon.NewLockFile(f) + if err := l.Lock(); err != nil { + return closeAndDeleteFileFn, errors.Wrap(ctx, err, op) + } + // Now that we have aquired the lock and verified we own the pid file + // we can remove it always when cleaning up. + unlockAndCleanFn := func() error { + return l.Remove() + } + if err := l.WritePid(); err != nil { + return unlockAndCleanFn, errors.Wrap(ctx, err, op) + } + return unlockAndCleanFn, nil +} + +func pidFileInUse(ctx context.Context, pidFile string) (*os.Process, error) { + const op = "daemon.pidFileInUse" + if pidFile == "" { + return nil, errors.New(ctx, errors.InvalidParameter, op, "pid filename is empty") + } + proc, err := (&daemon.Context{PidFileName: pidFile}).Search() + if err != nil && !errors.Is(err, os.ErrNotExist) { + return nil, err + } + return proc, nil +} + +type pidCleanup func() error + +var noopPidCleanup pidCleanup = func() error { return nil } diff --git a/internal/daemon/cache/pidfile_test.go b/internal/daemon/cache/pidfile_test.go new file mode 100644 index 00000000000..abe64e34dbb --- /dev/null +++ b/internal/daemon/cache/pidfile_test.go @@ -0,0 +1,61 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package cache + +import ( + "context" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPidInUse(t *testing.T) { + ctx := context.Background() + dotPath := t.TempDir() + pidPath := filepath.Join(dotPath, "boundary.pid") + + used, err := pidFileInUse(ctx, pidPath) + assert.NoError(t, err) + assert.Nil(t, used) + + workingPidCleanup1, err := writePidFile(ctx, pidPath) + assert.NoError(t, err) + + used, err = pidFileInUse(ctx, pidPath) + assert.NoError(t, err) + assert.NotNil(t, used) + + failingPidCleanup, err := writePidFile(ctx, pidPath) + assert.Error(t, err) + + used, err = pidFileInUse(ctx, pidPath) + assert.NoError(t, err) + assert.NotNil(t, used) + + assert.NoError(t, failingPidCleanup()) + + used, err = pidFileInUse(ctx, pidPath) + assert.NoError(t, err) + assert.NotNil(t, used) + + assert.NoError(t, workingPidCleanup1()) + + used, err = pidFileInUse(ctx, pidPath) + assert.NoError(t, err) + assert.Nil(t, used) + + workingPidCleanup2, err := writePidFile(ctx, pidPath) + assert.NoError(t, err) + + used, err = pidFileInUse(ctx, pidPath) + assert.NoError(t, err) + assert.NotNil(t, used) + + assert.NoError(t, workingPidCleanup2()) + + used, err = pidFileInUse(ctx, pidPath) + assert.NoError(t, err) + assert.Nil(t, used) +} diff --git a/internal/daemon/cache/pidfile_windows.go b/internal/daemon/cache/pidfile_windows.go new file mode 100644 index 00000000000..4875a0c9a11 --- /dev/null +++ b/internal/daemon/cache/pidfile_windows.go @@ -0,0 +1,100 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +//go:build windows +// +build windows + +package cache + +import ( + "bytes" + "context" + stderrors "errors" + "fmt" + "io" + "os" + "strings" + + "github.com/hashicorp/boundary/internal/errors" + "golang.org/x/sys/windows" +) + +func writePidFile(ctx context.Context, pidFile string) (pidCleanup, error) { + const op = "daemon.writePidFile" + // Create the file for writing, and set shared read so following processes + // cannot open this file for writing. + fd, err := windows.CreateFile(&(windows.StringToUTF16(pidFile)[0]), windows.GENERIC_WRITE, + windows.FILE_SHARE_READ, nil, windows.OPEN_ALWAYS, windows.FILE_ATTRIBUTE_NORMAL, 0) + if err != nil { + return noopPidCleanup, errors.Wrap(ctx, err, op) + } + + cleanup := func() error { + var err error + if inErr := windows.CloseHandle(fd); inErr != nil { + err = stderrors.Join(err, errors.Wrap(ctx, inErr, op, errors.WithMsg("handler close"))) + } + if inErr := windows.DeleteFile(&(windows.StringToUTF16(pidFile)[0])); inErr != nil { + err = stderrors.Join(err, errors.Wrap(ctx, inErr, op, errors.WithMsg("removing file"))) + } + return err + } + + if _, err := windows.Seek(fd, 0, windows.FILE_BEGIN); err != nil { + return cleanup, errors.Wrap(ctx, err, op) + } + b := bytes.NewBuffer(nil) + if _, err := fmt.Fprint(b, os.Getpid()); err != nil { + return cleanup, errors.Wrap(ctx, err, op, errors.WithMsg("writing file buffer")) + } + var fileLen int + if fileLen, err = windows.Write(fd, b.Bytes()); err != nil { + return cleanup, errors.Wrap(ctx, err, op, errors.WithMsg("writing buffer to file")) + } + if err = windows.Ftruncate(fd, int64(fileLen)); err != nil { + return cleanup, errors.Wrap(ctx, err, op) + } + + return cleanup, windows.Fsync(fd) +} + +func pidFileInUse(ctx context.Context, pidFile string) (*os.Process, error) { + const op = "daemon.pidFileInUse" + if pidFile == "" { + return nil, errors.New(ctx, errors.InvalidParameter, op, "pid filename is empty") + } + + var err error + var file *os.File + if file, err = os.OpenFile(pidFile, os.O_RDONLY, 0o640); err != nil && !errors.Is(err, os.ErrNotExist) { + return nil, errors.Wrap(ctx, err, op, errors.WithMsg("OpenFile")) + } + if file == nil { + return nil, nil + } + defer func() { + file.Close() + }() + + if _, err = file.Seek(0, io.SeekStart); err != nil { + return nil, err + } + var pid int + _, err = fmt.Fscan(file, &pid) + if err != nil { + return nil, errors.Wrap(ctx, err, op, errors.WithMsg("Fscan")) + } + + p, err := os.FindProcess(pid) + if err != nil { + if strings.Contains(err.Error(), "The parameter is incorrect") { + return nil, errors.New(ctx, errors.NotFound, op, "cannot find process") + } + // we failed to get the process for whatever reason + return nil, errors.Wrap(ctx, err, op, errors.WithMsg("FindProcess %d", pid)) + } + if p == nil { + return nil, errors.New(ctx, errors.NotFound, op, "cannot find process") + } + return p, nil +} diff --git a/internal/daemon/cache/query.go b/internal/daemon/cache/query.go deleted file mode 100644 index 55358d2395f..00000000000 --- a/internal/daemon/cache/query.go +++ /dev/null @@ -1,4 +0,0 @@ -// Copyright (c) HashiCorp, Inc. -// SPDX-License-Identifier: BUSL-1.1 - -package cache diff --git a/internal/cmd/commands/daemon/search_handler.go b/internal/daemon/cache/search_handler.go similarity index 97% rename from internal/cmd/commands/daemon/search_handler.go rename to internal/daemon/cache/search_handler.go index e6722b4d874..2c88e41abd7 100644 --- a/internal/cmd/commands/daemon/search_handler.go +++ b/internal/daemon/cache/search_handler.go @@ -1,7 +1,7 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 -package daemon +package cache import ( "context" @@ -11,7 +11,7 @@ import ( "github.com/hashicorp/boundary/api/sessions" "github.com/hashicorp/boundary/api/targets" - "github.com/hashicorp/boundary/internal/daemon/cache" + "github.com/hashicorp/boundary/internal/cache" "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/util" ) diff --git a/internal/cmd/commands/daemon/server.go b/internal/daemon/cache/server.go similarity index 82% rename from internal/cmd/commands/daemon/server.go rename to internal/daemon/cache/server.go index 4c7113e4c13..a079a8dfda5 100644 --- a/internal/cmd/commands/daemon/server.go +++ b/internal/daemon/cache/server.go @@ -1,7 +1,7 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 -package daemon +package cache import ( "context" @@ -18,9 +18,11 @@ import ( "github.com/hashicorp/boundary/api" "github.com/hashicorp/boundary/api/authtokens" + "github.com/hashicorp/boundary/internal/cache" + cachedb "github.com/hashicorp/boundary/internal/cache/db" "github.com/hashicorp/boundary/internal/cmd/base" "github.com/hashicorp/boundary/internal/cmd/base/logging" - "github.com/hashicorp/boundary/internal/daemon/cache" + "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/event" "github.com/hashicorp/boundary/internal/util" @@ -42,14 +44,14 @@ type ClientProvider interface { Client(opt ...base.Option) (*api.Client, error) } -type cacheServer struct { - conf *serverConfig +type CacheServer struct { + conf *Config infoKeys []string info map[string]string storeUrl string - store *cache.Store + store *db.DB tickerWg *sync.WaitGroup httpSrv *http.Server @@ -57,53 +59,56 @@ type cacheServer struct { shutdownOnce *sync.Once } -type serverConfig struct { - contextCancel context.CancelFunc - refreshIntervalSeconds int64 - flagDatabaseUrl string - flagStoreDebug bool - flagLogLevel string - flagLogFormat string - logWriter io.Writer +type Config struct { + ContextCancel context.CancelFunc + RefreshIntervalSeconds int64 + DatabaseUrl string + StoreDebug bool + LogLevel string + LogFormat string + LogWriter io.Writer + DotDirectory string } -func (sc *serverConfig) validate(ctx context.Context) error { +func (sc *Config) validate(ctx context.Context) error { const op = "daemon.(serverConfig).validate" switch { - case util.IsNil(sc.logWriter): + case util.IsNil(sc.LogWriter): return errors.New(ctx, errors.InvalidParameter, op, "missing log writter") - case util.IsNil(sc.contextCancel): + case util.IsNil(sc.ContextCancel): return errors.New(ctx, errors.InvalidParameter, op, "missing contextCancel") + case sc.DotDirectory == "": + return errors.New(ctx, errors.InvalidParameter, op, "missing dot directory") } return nil } // can be called before eventing is setup -func newServer(ctx context.Context, conf *serverConfig) (*cacheServer, error) { +func New(ctx context.Context, conf *Config) (*CacheServer, error) { const op = "daemon.newServer" if err := conf.validate(ctx); err != nil { return nil, errors.Wrap(ctx, err, op) } - s := &cacheServer{ + s := &CacheServer{ conf: conf, info: make(map[string]string), infoKeys: make([]string, 0, 20), tickerWg: new(sync.WaitGroup), shutdownOnce: new(sync.Once), } - if err := s.setupLogging(ctx, conf.logWriter); err != nil { + if err := s.setupLogging(ctx, conf.LogWriter); err != nil { return nil, errors.Wrap(ctx, err, op) } return s, nil } -func (s *cacheServer) shutdown(ctx context.Context) error { +func (s *CacheServer) Shutdown(ctx context.Context) error { const op = "daemon.(cacheServer).Shutdown" var shutdownErr error s.shutdownOnce.Do(func() { - if s.conf.contextCancel != nil { - s.conf.contextCancel() + if s.conf.ContextCancel != nil { + s.conf.ContextCancel() } srvCtx, srvCancel := context.WithTimeout(context.Background(), 5*time.Second) defer srvCancel() @@ -157,11 +162,11 @@ func defaultBoundaryTokenReader(ctx context.Context, cp ClientProvider) (cache.B }, nil } -// start will fire up the refresh goroutine and the caching API http server as a +// Serve will fire up the refresh goroutine and the caching API http server as a // daemon. The daemon bits are included so it's easy for CLI cmds to start the // a cache server -func (s *cacheServer) serve(ctx context.Context, cmd Commander, l net.Listener, opt ...Option) error { - const op = "daemon.(cacheServer).start" +func (s *CacheServer) Serve(ctx context.Context, cmd Commander, opt ...Option) error { + const op = "daemon.(cacheServer).Serve" switch { case util.IsNil(ctx): return errors.New(ctx, errors.InvalidParameter, op, "context is missing") @@ -178,12 +183,17 @@ func (s *cacheServer) serve(ctx context.Context, cmd Commander, l net.Listener, } } + l, err := Listener(ctx, s.conf.DotDirectory) + if err != nil { + return errors.Wrap(ctx, err, op) + } + s.info["Listening address"] = l.Addr().String() s.infoKeys = append(s.infoKeys, "Listening address") - s.info["Store debug"] = strconv.FormatBool(s.conf.flagStoreDebug) + s.info["Store debug"] = strconv.FormatBool(s.conf.StoreDebug) s.infoKeys = append(s.infoKeys, "Store debug") - if s.store, s.storeUrl, err = openStore(ctx, s.conf.flagDatabaseUrl, s.conf.flagStoreDebug); err != nil { + if s.store, s.storeUrl, err = openStore(ctx, s.conf.DatabaseUrl, s.conf.StoreDebug); err != nil { return errors.Wrap(ctx, err, op) } s.info["Database URL"] = s.storeUrl @@ -217,7 +227,11 @@ func (s *cacheServer) serve(ctx context.Context, cmd Commander, l net.Listener, } }() - tic, err := newRefreshTicker(ctx, s.conf.refreshIntervalSeconds, repo.Refresh) + refreshService, err := cache.NewRefreshService(ctx, repo) + if err != nil { + return errors.Wrap(ctx, err, op) + } + tic, err := newRefreshTicker(ctx, s.conf.RefreshIntervalSeconds, refreshService) if err != nil { return errors.Wrap(ctx, err, op) } @@ -244,7 +258,7 @@ func (s *cacheServer) serve(ctx context.Context, cmd Commander, l net.Listener, } mux.HandleFunc("/v1/tokens", tokenFn) - stopFn, err := newStopHandlerFunc(ctx, s.conf.contextCancel) + stopFn, err := newStopHandlerFunc(ctx, s.conf.ContextCancel) if err != nil { return errors.Wrap(ctx, err, op) } @@ -270,7 +284,7 @@ func (s *cacheServer) serve(ctx context.Context, cmd Commander, l net.Listener, return nil } -func (s *cacheServer) printInfo(ctx context.Context) { +func (s *CacheServer) printInfo(ctx context.Context) { const op = "daemon.(cacheServer).printInfo" verInfo := version.Get() if verInfo.Version != "" { @@ -313,7 +327,7 @@ func (s *cacheServer) printInfo(ctx context.Context) { event.WriteSysEvent(ctx, op, strings.Join(output, "\n")) } -func (s *cacheServer) setupLogging(ctx context.Context, w io.Writer) error { +func (s *CacheServer) setupLogging(ctx context.Context, w io.Writer) error { const op = "daemon.(Command).setupLogging" switch { case util.IsNil(w): @@ -321,15 +335,15 @@ func (s *cacheServer) setupLogging(ctx context.Context, w io.Writer) error { } logFormat := logging.StandardFormat - if s.conf.flagLogFormat != "" { + if s.conf.LogFormat != "" { var err error - logFormat, err = logging.ParseLogFormat(s.conf.flagLogFormat) + logFormat, err = logging.ParseLogFormat(s.conf.LogFormat) if err != nil { return fmt.Errorf("%s: %w", op, err) } } - logLevel := strings.ToLower(strings.TrimSpace(s.conf.flagLogLevel)) + logLevel := strings.ToLower(strings.TrimSpace(s.conf.LogLevel)) if logLevel == "" { logLevel = "info" } @@ -419,19 +433,19 @@ func setupEventing(ctx context.Context, logger hclog.Logger, serializationLock * return nil } -func openStore(ctx context.Context, url string, flagDebugStore bool) (*cache.Store, string, error) { +func openStore(ctx context.Context, url string, flagDebugStore bool) (*db.DB, string, error) { const op = "daemon.openStore" var err error + opts := []cachedb.Option{cachedb.WithDebug(flagDebugStore)} switch { case url != "": url, err = parseutil.ParsePath(url) if err != nil && !errors.Is(err, parseutil.ErrNotAUrl) { return nil, "", errors.Wrap(ctx, err, op) } - default: - url = cache.DefaultStoreUrl + opts = append(opts, cachedb.WithUrl(url)) } - store, err := cache.Open(ctx, cache.WithUrl(url), cache.WithDebug(flagDebugStore)) + store, err := cachedb.Open(ctx, opts...) if err != nil { return nil, "", errors.Wrap(ctx, err, op) } diff --git a/internal/cmd/commands/daemon/server_test.go b/internal/daemon/cache/server_test.go similarity index 99% rename from internal/cmd/commands/daemon/server_test.go rename to internal/daemon/cache/server_test.go index 273cd5c8022..0759ec99fbb 100644 --- a/internal/cmd/commands/daemon/server_test.go +++ b/internal/daemon/cache/server_test.go @@ -1,7 +1,7 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 -package daemon +package cache import ( "context" diff --git a/internal/cmd/commands/daemon/stop_handler.go b/internal/daemon/cache/stop_handler.go similarity index 98% rename from internal/cmd/commands/daemon/stop_handler.go rename to internal/daemon/cache/stop_handler.go index 94d4067c8b2..8a1b97cdace 100644 --- a/internal/cmd/commands/daemon/stop_handler.go +++ b/internal/daemon/cache/stop_handler.go @@ -1,7 +1,7 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 -package daemon +package cache import ( "context" diff --git a/internal/cmd/commands/daemon/stop_handler_test.go b/internal/daemon/cache/stop_handler_test.go similarity index 98% rename from internal/cmd/commands/daemon/stop_handler_test.go rename to internal/daemon/cache/stop_handler_test.go index ce2b286ad2f..8802305c04f 100644 --- a/internal/cmd/commands/daemon/stop_handler_test.go +++ b/internal/daemon/cache/stop_handler_test.go @@ -1,7 +1,7 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 -package daemon +package cache import ( "context" diff --git a/internal/cmd/commands/daemon/testing.go b/internal/daemon/cache/testing.go similarity index 76% rename from internal/cmd/commands/daemon/testing.go rename to internal/daemon/cache/testing.go index e22b123206f..7f6b8b83305 100644 --- a/internal/cmd/commands/daemon/testing.go +++ b/internal/daemon/cache/testing.go @@ -1,7 +1,7 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 -package daemon +package cache import ( "context" @@ -12,12 +12,12 @@ import ( "github.com/hashicorp/boundary/api/authtokens" "github.com/hashicorp/boundary/api/sessions" "github.com/hashicorp/boundary/api/targets" - "github.com/hashicorp/boundary/internal/daemon/cache" + "github.com/hashicorp/boundary/internal/cache" "github.com/stretchr/testify/require" ) type TestServer struct { - *cacheServer + *CacheServer socketDir string cmd Commander } @@ -31,17 +31,17 @@ func NewTestServer(t *testing.T, cmd Commander, opt ...Option) *TestServer { opts, err := getOpts(opt...) require.NoError(t, err) - cfg := &serverConfig{ - contextCancel: cancel, - refreshIntervalSeconds: DefaultRefreshIntervalSeconds, - flagStoreDebug: opts.withDebug, - logWriter: io.Discard, + cfg := &Config{ + ContextCancel: cancel, + RefreshIntervalSeconds: DefaultRefreshIntervalSeconds, + StoreDebug: opts.withDebug, + LogWriter: io.Discard, } - s, err := newServer(ctx, cfg) + s, err := New(ctx, cfg) require.NoError(t, err) return &TestServer{ - cacheServer: s, + CacheServer: s, socketDir: t.TempDir(), cmd: cmd, } @@ -59,13 +59,10 @@ func (s *TestServer) Serve(t *testing.T, opt ...Option) error { t.Helper() ctx := context.Background() - l, err := listener(ctx, s.socketDir) - require.NoError(t, err) - t.Cleanup(func() { - s.shutdown(ctx) + s.Shutdown(ctx) }) - return s.cacheServer.serve(ctx, s.cmd, l, opt...) + return s.CacheServer.Serve(ctx, s.cmd, opt...) } // AddResources adds targets to the cache for the provided address, token name, @@ -73,7 +70,7 @@ func (s *TestServer) Serve(t *testing.T, opt ...Option) error { func (s *TestServer) AddResources(t *testing.T, p *authtokens.AuthToken, tars []*targets.Target, sess []*sessions.Session, atReadFn cache.BoundaryTokenReaderFn) { t.Helper() ctx := context.Background() - r, err := cache.NewRepository(ctx, s.cacheServer.store, &sync.Map{}, s.cmd.ReadTokenFromKeyring, atReadFn) + r, err := cache.NewRepository(ctx, s.CacheServer.store, &sync.Map{}, s.cmd.ReadTokenFromKeyring, atReadFn) require.NoError(t, err) tarFn := func(ctx context.Context, _, tok string) ([]*targets.Target, error) { @@ -88,5 +85,7 @@ func (s *TestServer) AddResources(t *testing.T, p *authtokens.AuthToken, tars [] } return sess, nil } - require.NoError(t, r.Refresh(ctx, cache.WithTargetRetrievalFunc(tarFn), cache.WithSessionRetrievalFunc(sessFn))) + rs, err := cache.NewRefreshService(ctx, r) + require.NoError(t, err) + require.NoError(t, rs.Refresh(ctx, cache.WithTargetRetrievalFunc(tarFn), cache.WithSessionRetrievalFunc(sessFn))) } diff --git a/internal/cmd/commands/daemon/ticker.go b/internal/daemon/cache/ticker.go similarity index 83% rename from internal/cmd/commands/daemon/ticker.go rename to internal/daemon/cache/ticker.go index 4ab43178aae..ec8e112f4f2 100644 --- a/internal/cmd/commands/daemon/ticker.go +++ b/internal/daemon/cache/ticker.go @@ -1,35 +1,38 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 -package daemon +package cache import ( "context" "time" - "github.com/hashicorp/boundary/internal/daemon/cache" + "github.com/hashicorp/boundary/internal/cache" "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/event" "github.com/hashicorp/boundary/internal/util" ) +const DefaultRefreshIntervalSeconds = 5 * 60 const defaultRefreshInterval = DefaultRefreshIntervalSeconds * time.Second -type refreshFn func(context.Context, ...cache.Option) error +type refreshService interface { + Refresh(context.Context, ...cache.Option) error +} type refreshTicker struct { tickerCtx context.Context refreshInterval time.Duration refreshChan chan struct{} - refreshFn refreshFn + refresher refreshService } -func newRefreshTicker(ctx context.Context, refreshIntervalSeconds int64, refreshFn refreshFn) (*refreshTicker, error) { +func newRefreshTicker(ctx context.Context, refreshIntervalSeconds int64, refresh refreshService) (*refreshTicker, error) { const op = "daemon.newRefreshTicker" switch { case refreshIntervalSeconds == 0: return nil, errors.New(ctx, errors.InvalidParameter, op, "refresh interval seconds is missing") - case util.IsNil(refreshFn): + case util.IsNil(refresh): return nil, errors.New(ctx, errors.InvalidParameter, op, "refreshing function is missing") } @@ -39,7 +42,7 @@ func newRefreshTicker(ctx context.Context, refreshIntervalSeconds int64, refresh } return &refreshTicker{ refreshInterval: refreshInterval, - refreshFn: refreshFn, + refresher: refresh, // We make this channel size 1 so if something happens midway through the refresh // we can immediately refresh again immediately to pick up something that might have been @@ -69,7 +72,7 @@ func (rt *refreshTicker) start(ctx context.Context) { case <-timer.C: case <-rt.refreshChan: } - if err := rt.refreshFn(ctx); err != nil { + if err := rt.refresher.Refresh(ctx); err != nil { event.WriteError(rt.tickerCtx, op, err) } timer.Reset(rt.refreshInterval) diff --git a/internal/cmd/commands/daemon/ticker_test.go b/internal/daemon/cache/ticker_test.go similarity index 66% rename from internal/cmd/commands/daemon/ticker_test.go rename to internal/daemon/cache/ticker_test.go index 110d4d7c1ac..b2d01cb24c9 100644 --- a/internal/cmd/commands/daemon/ticker_test.go +++ b/internal/daemon/cache/ticker_test.go @@ -1,14 +1,14 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 -package daemon +package cache import ( "context" "testing" "time" - "github.com/hashicorp/boundary/internal/daemon/cache" + "github.com/hashicorp/boundary/internal/cache" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -16,13 +16,9 @@ import ( func TestTickerRefresh(t *testing.T) { ctx := context.Background() - called := make(chan struct{}) - testFunc := func(context.Context, ...cache.Option) error { - called <- struct{}{} - return nil - } + refresher := &fakeRefresher{make(chan struct{})} - rt, err := newRefreshTicker(ctx, 60, testFunc) + rt, err := newRefreshTicker(ctx, 60, refresher) require.NoError(t, err) tickerCtx, tickerCancel := context.WithCancel(ctx) @@ -32,21 +28,30 @@ func TestTickerRefresh(t *testing.T) { }) // let the normal start ticker refresh things, - <-called + <-refresher.called testCtx, testCancel := context.WithTimeout(ctx, 100*time.Millisecond) defer testCancel() rt.refresh() select { - case <-called: + case <-refresher.called: case <-testCtx.Done(): assert.Fail(t, "timed out waiting for the refresh ") } // wait and make sure we don't get yet another call select { - case <-called: + case <-refresher.called: assert.Fail(t, "received an unexpected refresh call") case <-testCtx.Done(): } } + +type fakeRefresher struct { + called chan struct{} +} + +func (r *fakeRefresher) Refresh(context.Context, ...cache.Option) error { + r.called <- struct{}{} + return nil +} diff --git a/internal/cmd/commands/daemon/token_handler.go b/internal/daemon/cache/token_handler.go similarity index 94% rename from internal/cmd/commands/daemon/token_handler.go rename to internal/daemon/cache/token_handler.go index c64f15eac75..dc5efbe512d 100644 --- a/internal/cmd/commands/daemon/token_handler.go +++ b/internal/daemon/cache/token_handler.go @@ -1,7 +1,7 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 -package daemon +package cache import ( "context" @@ -11,8 +11,8 @@ import ( "net/http" "strings" + "github.com/hashicorp/boundary/internal/cache" "github.com/hashicorp/boundary/internal/cmd/base" - "github.com/hashicorp/boundary/internal/daemon/cache" "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/util" ) @@ -21,8 +21,8 @@ type refresher interface { refresh() } -// keyringToken has keyring held auth token information. -type keyringToken struct { +// KeyringToken has keyring held auth token information. +type KeyringToken struct { // The keyring type used by boundary to access the auth token KeyringType string // The token identifier for the provided keyring type that holds the auth token @@ -30,7 +30,7 @@ type keyringToken struct { } // userTokenToAdd is the request body to this handler. -type upsertTokenRequest struct { +type UpsertTokenRequest struct { // BoundaryAddr is a required field for all requests BoundaryAddr string // The id of the auth token asserted to be attempted to be added @@ -40,7 +40,7 @@ type upsertTokenRequest struct { AuthToken string // Keyring is the keyring info used when adding an auth token held in // keyring to the daemon. - Keyring *keyringToken + Keyring *KeyringToken } func newTokenHandlerFunc(ctx context.Context, repo *cache.Repository, refresher refresher) (http.HandlerFunc, error) { @@ -58,7 +58,7 @@ func newTokenHandlerFunc(ctx context.Context, repo *cache.Repository, refresher writeError(w, "method not allowed", http.StatusMethodNotAllowed) return } - var perReq upsertTokenRequest + var perReq UpsertTokenRequest data, err := io.ReadAll(r.Body) if err != nil { diff --git a/internal/cmd/commands/daemon/version_interceptor.go b/internal/daemon/cache/version_interceptor.go similarity index 98% rename from internal/cmd/commands/daemon/version_interceptor.go rename to internal/daemon/cache/version_interceptor.go index 5bda6235163..95c5d9c6775 100644 --- a/internal/cmd/commands/daemon/version_interceptor.go +++ b/internal/daemon/cache/version_interceptor.go @@ -1,7 +1,7 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 -package daemon +package cache import ( "fmt" diff --git a/internal/cmd/commands/daemon/version_interceptor_test.go b/internal/daemon/cache/version_interceptor_test.go similarity index 98% rename from internal/cmd/commands/daemon/version_interceptor_test.go rename to internal/daemon/cache/version_interceptor_test.go index d157b91270d..79d5fe1f960 100644 --- a/internal/cmd/commands/daemon/version_interceptor_test.go +++ b/internal/daemon/cache/version_interceptor_test.go @@ -1,7 +1,7 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 -package daemon +package cache import ( "net/http"