From 7ff2de506c77ea66c41a25a343c407affb4b396f Mon Sep 17 00:00:00 2001 From: Todd Date: Fri, 22 Sep 2023 16:48:13 -0700 Subject: [PATCH] Collect resource info in repository_s.go --- .../cmd/commands/daemon/search_handler.go | 82 ++++++++-------- internal/daemon/cache/repository_refresh.go | 95 +------------------ .../daemon/cache/repository_refresh_test.go | 8 ++ internal/daemon/cache/repository_sessions.go | 76 ++++++++++++--- .../daemon/cache/repository_sessions_test.go | 9 +- internal/daemon/cache/repository_targets.go | 73 +++++++++++--- .../daemon/cache/repository_targets_test.go | 9 +- 7 files changed, 191 insertions(+), 161 deletions(-) diff --git a/internal/cmd/commands/daemon/search_handler.go b/internal/cmd/commands/daemon/search_handler.go index 369aff52f5e..9f92daa39a2 100644 --- a/internal/cmd/commands/daemon/search_handler.go +++ b/internal/cmd/commands/daemon/search_handler.go @@ -38,6 +38,24 @@ func newSearchTargetsHandlerFunc(ctx context.Context, repo *cache.Repository) (h case util.IsNil(repo): return nil, errors.New(ctx, errors.InvalidParameter, op, "repository is missing") } + + searchableResources := map[string]searcher{ + "targets": &searchFns[*targets.Target]{ + list: repo.ListTargets, + query: repo.QueryTargets, + searchResult: func(t []*targets.Target) *SearchResult { + return &SearchResult{Targets: t} + }, + }, + "sessions": &searchFns[*sessions.Session]{ + list: repo.ListSessions, + query: repo.QuerySessions, + searchResult: func(s []*sessions.Session) *SearchResult { + return &SearchResult{Sessions: s} + }, + }, + } + return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() filter, err := handlers.NewFilter(ctx, r.URL.Query().Get(filterKey)) @@ -66,17 +84,12 @@ func newSearchTargetsHandlerFunc(ctx context.Context, repo *cache.Repository) (h query := r.URL.Query().Get(queryKey) - var res *SearchResult - switch resource { - case "targets": - res, err = searchTargets(r.Context(), repo, authTokenId, query, filter) - case "sessions": - res, err = searchSessions(r.Context(), repo, authTokenId, query, filter) - default: + rSearcher, ok := searchableResources[resource] + if !ok { writeError(w, fmt.Sprintf("search doesn't support %q resource", resource), http.StatusBadRequest) return } - + res, err := rSearcher.search(r.Context(), authTokenId, query, filter) if err != nil { switch { case errors.Match(errors.T(errors.InvalidParameter), err): @@ -99,50 +112,41 @@ func newSearchTargetsHandlerFunc(ctx context.Context, repo *cache.Repository) (h }, nil } -func searchTargets(ctx context.Context, repo *cache.Repository, authTokenId, query string, filter *handlers.Filter) (*SearchResult, error) { - var found []*targets.Target - var err error - switch query { - case "": - found, err = repo.ListTargets(ctx, authTokenId) - default: - found, err = repo.QueryTargets(ctx, authTokenId, query) - } - if err != nil { - return nil, err - } +type searcher interface { + search(ctx context.Context, authTokenId, query string, filter *handlers.Filter) (*SearchResult, error) +} - finalTars := make([]*targets.Target, 0, len(found)) - for _, item := range found { - if filter.Match(item) { - finalTars = append(finalTars, item) - } - } - return &SearchResult{ - Targets: finalTars, - }, nil +// searchFns is a struct that collects all the functions needed to perform a search +// on a specific resource type. +type searchFns[T any] struct { + // list takes a context and an auth token and returns all resources for the + // user of that auth token. + list func(context.Context, string) ([]T, error) + // query takes a context, an auth token, and a query string and returns all + // resources for that auth token that matches the provided query parameter + query func(context.Context, string, string) ([]T, error) + searchResult func([]T) *SearchResult } -func searchSessions(ctx context.Context, repo *cache.Repository, authTokenId, query string, filter *handlers.Filter) (*SearchResult, error) { - var found []*sessions.Session +func (l *searchFns[T]) search(ctx context.Context, authTokenId, query string, filter *handlers.Filter) (*SearchResult, error) { + const op = "daemon.(lookupFns).search" + var found []T var err error switch query { case "": - found, err = repo.ListSessions(ctx, authTokenId) + found, err = l.list(ctx, authTokenId) default: - found, err = repo.QuerySessions(ctx, authTokenId, query) + found, err = l.query(ctx, authTokenId, query) } if err != nil { - return nil, err + return nil, errors.Wrap(ctx, err, op) } - finalSess := make([]*sessions.Session, 0, len(found)) + finalResults := make([]T, 0, len(found)) for _, item := range found { if filter.Match(item) { - finalSess = append(finalSess, item) + finalResults = append(finalResults, item) } } - return &SearchResult{ - Sessions: finalSess, - }, nil + return l.searchResult(finalResults), nil } diff --git a/internal/daemon/cache/repository_refresh.go b/internal/daemon/cache/repository_refresh.go index 1d070523b8b..bf67ae8961b 100644 --- a/internal/daemon/cache/repository_refresh.go +++ b/internal/daemon/cache/repository_refresh.go @@ -6,58 +6,11 @@ package cache import ( "context" stderrors "errors" - "fmt" - "github.com/hashicorp/boundary/api" "github.com/hashicorp/boundary/api/authtokens" - "github.com/hashicorp/boundary/api/sessions" - "github.com/hashicorp/boundary/api/targets" "github.com/hashicorp/boundary/internal/errors" - "github.com/hashicorp/boundary/internal/observability/event" ) -// TargetRetrievalFunc is a function that retrieves targets -// from the provided boundary addr using the provided token. -type TargetRetrievalFunc func(ctx context.Context, addr, token string) ([]*targets.Target, error) - -func defaultTargetFunc(ctx context.Context, addr, token string) ([]*targets.Target, error) { - const op = "cache.defaultTargetFunc" - client, err := api.NewClient(&api.Config{ - Addr: addr, - Token: token, - }) - if err != nil { - return nil, errors.Wrap(ctx, err, op) - } - tarClient := targets.NewClient(client) - l, err := tarClient.List(ctx, "global", targets.WithRecursive(true)) - if err != nil { - return nil, errors.Wrap(ctx, err, op) - } - return l.Items, nil -} - -// SessionRetrievalFunc is a function that retrieves sessions -// from the provided boundary addr using the provided token. -type SessionRetrievalFunc func(ctx context.Context, addr, token string) ([]*sessions.Session, error) - -func defaultSessionFunc(ctx context.Context, addr, token string) ([]*sessions.Session, error) { - const op = "cache.defaultSessionFunc" - client, err := api.NewClient(&api.Config{ - Addr: addr, - Token: token, - }) - if err != nil { - return nil, errors.Wrap(ctx, err, op) - } - sClient := sessions.NewClient(client) - l, err := sClient.List(ctx, "global", sessions.WithRecursive(true)) - if err != nil { - return nil, errors.Wrap(ctx, err, op) - } - return l.Items, nil -} - // cleanAndPickAuthTokens removes from the cache all auth tokens which are // evicted from the cache or no longer stored in a keyring and returns the // remaining ones. @@ -106,22 +59,10 @@ func (r *Repository) Refresh(ctx context.Context, opt ...Option) error { return errors.Wrap(ctx, err, op) } - opts, err := getOpts(opt...) - if err != nil { - return errors.Wrap(ctx, err, op) - } - us, err := r.listUsers(ctx) if err != nil { return errors.Wrap(ctx, err, op) } - if opts.withTargetRetrievalFunc == nil { - opts.withTargetRetrievalFunc = defaultTargetFunc - } - if opts.withSessionRetrievalFunc == nil { - opts.withSessionRetrievalFunc = defaultSessionFunc - } - var retErr error for _, u := range us { tokens, err := r.cleanAndPickAuthTokens(ctx, u) @@ -130,39 +71,13 @@ func (r *Repository) Refresh(ctx context.Context, opt ...Option) error { continue } - // Find and use a token for retrieving targets - for at, t := range tokens { - resp, err := opts.withTargetRetrievalFunc(ctx, u.Address, t) - if err != nil { - // TODO: If we get an error about the token no longer having - // permissions, remove it. - retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for token %q", at.Id))) - continue - } - - event.WriteSysEvent(ctx, op, fmt.Sprintf("updating %d targets for user %v", len(resp), u)) - if err := r.refreshTargets(ctx, u, resp); err != nil { - retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for user %v", u))) - } - break + if err := r.refreshTargets(ctx, u, tokens, opt...); err != nil { + retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op)) } - - // Find and use a token for retrieving sessions - for at, t := range tokens { - resp, err := opts.withSessionRetrievalFunc(ctx, u.Address, t) - if err != nil { - // TODO: If we get an error about the token no longer having - // permissions, remove it. - retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for token %q", at.Id))) - continue - } - - event.WriteSysEvent(ctx, op, fmt.Sprintf("updating %d sessions for user %v", len(resp), u)) - if err := r.refreshSessions(ctx, u, resp); err != nil { - retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for user %v", u))) - } - break + if err := r.refreshSessions(ctx, u, tokens, opt...); err != nil { + retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op)) } + } return retErr } diff --git a/internal/daemon/cache/repository_refresh_test.go b/internal/daemon/cache/repository_refresh_test.go index 2fb41b19abe..8061bacf52c 100644 --- a/internal/daemon/cache/repository_refresh_test.go +++ b/internal/daemon/cache/repository_refresh_test.go @@ -29,6 +29,14 @@ func noopRetrievalFn[T any](context.Context, string, string) ([]T, error) { return nil, nil } +// staticRetrievalFn returns a function that satisfies the With*RetrievalFn +// and returns the provided slice and a nil error always +func staticRetrievalFn[T any](ret []T) func(context.Context, string, string) ([]T, error) { + return func(ctx context.Context, s1, s2 string) ([]T, error) { + return ret, nil + } +} + func TestCleanAndPickTokens(t *testing.T) { ctx := context.Background() s, err := Open(ctx) diff --git a/internal/daemon/cache/repository_sessions.go b/internal/daemon/cache/repository_sessions.go index 58d969c11c8..351fe33c418 100644 --- a/internal/daemon/cache/repository_sessions.go +++ b/internal/daemon/cache/repository_sessions.go @@ -7,46 +7,99 @@ import ( "context" "database/sql" "encoding/json" - stdErrors "errors" + stderrors "errors" "fmt" + "github.com/hashicorp/boundary/api" "github.com/hashicorp/boundary/api/sessions" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/boundary/internal/observability/event" "github.com/hashicorp/boundary/internal/types/resource" "github.com/hashicorp/boundary/internal/util" "github.com/hashicorp/mql" ) -func (r *Repository) refreshSessions(ctx context.Context, u *user, sessions []*sessions.Session) error { +// SessionRetrievalFunc is a function that retrieves sessions +// from the provided boundary addr using the provided token. +type SessionRetrievalFunc func(ctx context.Context, addr, token string) ([]*sessions.Session, error) + +func defaultSessionFunc(ctx context.Context, addr, token string) ([]*sessions.Session, error) { + const op = "cache.defaultSessionFunc" + client, err := api.NewClient(&api.Config{ + Addr: addr, + Token: token, + }) + if err != nil { + return nil, errors.Wrap(ctx, err, op) + } + sClient := sessions.NewClient(client) + l, err := sClient.List(ctx, "global", sessions.WithRecursive(true)) + if err != nil { + return nil, errors.Wrap(ctx, err, op) + } + return l.Items, nil +} + +func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[AuthToken]string, opt ...Option) error { const op = "cache.(Repository).refreshSessions" switch { case util.IsNil(u): return errors.New(ctx, errors.InvalidParameter, op, "user is nil") case u.Id == "": return errors.New(ctx, errors.InvalidParameter, op, "user id is missing") + case u.Address == "": + return errors.New(ctx, errors.InvalidParameter, op, "user boundary address is missing") } - foundU := u.clone() - if err := r.rw.LookupById(ctx, foundU); err != nil { - // if this user isn't known, error out. - return errors.Wrap(ctx, err, op, errors.WithMsg("looking up user")) + opts, err := getOpts(opt...) + if err != nil { + return errors.Wrap(ctx, err, op) + } + if opts.withSessionRetrievalFunc == nil { + opts.withSessionRetrievalFunc = defaultSessionFunc + } + + // Find and use a token for retrieving sessions + var gotResponse bool + var resp []*sessions.Session + var retErr error + for at, t := range tokens { + resp, err = opts.withSessionRetrievalFunc(ctx, u.Address, t) + if err != nil { + // TODO: If we get an error about the token no longer having + // permissions, remove it. + retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for token %q", at.Id))) + continue + } + gotResponse = true + break } - _, err := r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(r db.Reader, w db.Writer) error { + if retErr != nil { + if saveErr := r.SaveError(ctx, u, resource.Session.String(), retErr); saveErr != nil { + return stderrors.Join(err, errors.Wrap(ctx, saveErr, op)) + } + } + if !gotResponse { + return retErr + } + + event.WriteSysEvent(ctx, op, fmt.Sprintf("updating %d sessions for user %v", len(resp), u)) + _, err = r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(r db.Reader, w db.Writer) error { // TODO: Instead of deleting everything, use refresh tokens and apply the delta if _, err := w.Exec(ctx, "delete from session where user_id = @user_id", - []any{sql.Named("user_id", foundU.Id)}); err != nil { + []any{sql.Named("user_id", u.Id)}); err != nil { return err } - for _, s := range sessions { + for _, s := range resp { item, err := json.Marshal(s) if err != nil { return err } newSession := &Session{ - UserId: foundU.Id, + UserId: u.Id, Id: s.Id, Type: s.Type, Status: s.Status, @@ -64,9 +117,6 @@ func (r *Repository) refreshSessions(ctx context.Context, u *user, sessions []*s return nil }) if err != nil { - if saveErr := r.SaveError(ctx, u, resource.Session.String(), err); saveErr != nil { - return stdErrors.Join(err, errors.Wrap(ctx, saveErr, op)) - } return errors.Wrap(ctx, err, op) } return nil diff --git a/internal/daemon/cache/repository_sessions_test.go b/internal/daemon/cache/repository_sessions_test.go index af3755c200e..033b19b9823 100644 --- a/internal/daemon/cache/repository_sessions_test.go +++ b/internal/daemon/cache/repository_sessions_test.go @@ -109,7 +109,8 @@ func TestRepository_refreshSessions(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - err := r.refreshSessions(ctx, tc.u, tc.sess) + err := r.refreshSessions(ctx, tc.u, map[AuthToken]string{{Id: "id"}: "something"}, + WithSessionRetrievalFunc(staticRetrievalFn(tc.sess))) if tc.errorContains == "" { assert.NoError(t, err) rw := db.New(s.conn) @@ -192,7 +193,8 @@ func TestRepository_ListSessions(t *testing.T) { Type: "tcp", }, } - require.NoError(t, r.refreshSessions(ctx, u1, ss)) + require.NoError(t, r.refreshSessions(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, + WithSessionRetrievalFunc(staticRetrievalFn(ss)))) t.Run("wrong user gets no sessions", func(t *testing.T) { l, err := r.ListSessions(ctx, kt2.AuthTokenId) @@ -298,7 +300,8 @@ func TestRepository_QuerySessions(t *testing.T) { Type: "tcp", }, } - require.NoError(t, r.refreshSessions(ctx, u1, ss)) + require.NoError(t, r.refreshSessions(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, + WithSessionRetrievalFunc(staticRetrievalFn(ss)))) t.Run("wrong token gets no sessions", func(t *testing.T) { l, err := r.QuerySessions(ctx, kt2.AuthTokenId, query) diff --git a/internal/daemon/cache/repository_targets.go b/internal/daemon/cache/repository_targets.go index fc0331f7b5b..c90af02b086 100644 --- a/internal/daemon/cache/repository_targets.go +++ b/internal/daemon/cache/repository_targets.go @@ -7,18 +7,41 @@ import ( "context" "database/sql" "encoding/json" - stdErrors "errors" + stderrors "errors" "fmt" + "github.com/hashicorp/boundary/api" "github.com/hashicorp/boundary/api/targets" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/boundary/internal/observability/event" "github.com/hashicorp/boundary/internal/types/resource" "github.com/hashicorp/boundary/internal/util" "github.com/hashicorp/mql" ) -func (r *Repository) refreshTargets(ctx context.Context, u *user, targets []*targets.Target) error { +// TargetRetrievalFunc is a function that retrieves targets +// from the provided boundary addr using the provided token. +type TargetRetrievalFunc func(ctx context.Context, addr, token string) ([]*targets.Target, error) + +func defaultTargetFunc(ctx context.Context, addr, token string) ([]*targets.Target, error) { + const op = "cache.defaultTargetFunc" + client, err := api.NewClient(&api.Config{ + Addr: addr, + Token: token, + }) + if err != nil { + return nil, errors.Wrap(ctx, err, op) + } + tarClient := targets.NewClient(client) + l, err := tarClient.List(ctx, "global", targets.WithRecursive(true)) + if err != nil { + return nil, errors.Wrap(ctx, err, op) + } + return l.Items, nil +} + +func (r *Repository) refreshTargets(ctx context.Context, u *user, tokens map[AuthToken]string, opt ...Option) error { const op = "cache.(Repository).refreshTargets" switch { case util.IsNil(u): @@ -27,26 +50,53 @@ func (r *Repository) refreshTargets(ctx context.Context, u *user, targets []*tar return errors.New(ctx, errors.InvalidParameter, op, "user id is missing") } - foundU := u.clone() - if err := r.rw.LookupById(ctx, foundU); err != nil { - // if this user isn't known, error out. - return errors.Wrap(ctx, err, op, errors.WithMsg("looking up user")) + opts, err := getOpts(opt...) + if err != nil { + return errors.Wrap(ctx, err, op) + } + if opts.withTargetRetrievalFunc == nil { + opts.withTargetRetrievalFunc = defaultTargetFunc + } + + // Find and use a token for retrieving targets + var gotResponse bool + var resp []*targets.Target + var retErr error + for at, t := range tokens { + resp, err = opts.withTargetRetrievalFunc(ctx, u.Address, t) + if err != nil { + // TODO: If we get an error about the token no longer having + // permissions, remove it. + retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for token %q", at.Id))) + continue + } + gotResponse = true + break + } + if retErr != nil { + if saveErr := r.SaveError(ctx, u, resource.Target.String(), retErr); saveErr != nil { + return stderrors.Join(err, errors.Wrap(ctx, saveErr, op)) + } + } + if !gotResponse { + return retErr } - _, err := r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(r db.Reader, w db.Writer) error { + event.WriteSysEvent(ctx, op, fmt.Sprintf("updating %d targets for user %v", len(resp), u)) + _, err = r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(r db.Reader, w db.Writer) error { // TODO: Instead of deleting everything, use refresh tokens and apply the delta if _, err := w.Exec(ctx, "delete from target where user_id = @user_id", - []any{sql.Named("user_id", foundU.Id)}); err != nil { + []any{sql.Named("user_id", u.Id)}); err != nil { return err } - for _, t := range targets { + for _, t := range resp { item, err := json.Marshal(t) if err != nil { return err } newTarget := &Target{ - UserId: foundU.Id, + UserId: u.Id, Id: t.Id, Name: t.Name, Description: t.Description, @@ -64,9 +114,6 @@ func (r *Repository) refreshTargets(ctx context.Context, u *user, targets []*tar return nil }) if err != nil { - if saveErr := r.SaveError(ctx, u, resource.Target.String(), err); saveErr != nil { - return stdErrors.Join(err, errors.Wrap(ctx, saveErr, op)) - } return errors.Wrap(ctx, err, op) } return nil diff --git a/internal/daemon/cache/repository_targets_test.go b/internal/daemon/cache/repository_targets_test.go index 5063f0488c5..85a5c968331 100644 --- a/internal/daemon/cache/repository_targets_test.go +++ b/internal/daemon/cache/repository_targets_test.go @@ -107,7 +107,8 @@ func TestRepository_refreshTargets(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - err := r.refreshTargets(ctx, tc.u, tc.targets) + err := r.refreshTargets(ctx, tc.u, map[AuthToken]string{{Id: "id"}: "something"}, + WithTargetRetrievalFunc(staticRetrievalFn(tc.targets))) if tc.errorContains == "" { assert.NoError(t, err) rw := db.New(s.conn) @@ -186,7 +187,8 @@ func TestRepository_ListTargets(t *testing.T) { SessionMaxSeconds: 333, }, } - require.NoError(t, r.refreshTargets(ctx, u1, ts)) + require.NoError(t, r.refreshTargets(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, + WithTargetRetrievalFunc(staticRetrievalFn(ts)))) t.Run("wrong user gets no targets", func(t *testing.T) { l, err := r.ListTargets(ctx, kt2.AuthTokenId) @@ -289,7 +291,8 @@ func TestRepository_QueryTargets(t *testing.T) { SessionMaxSeconds: 333, }, } - require.NoError(t, r.refreshTargets(ctx, u1, ts)) + require.NoError(t, r.refreshTargets(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, + WithTargetRetrievalFunc(staticRetrievalFn(ts)))) t.Run("wrong token gets no targets", func(t *testing.T) { l, err := r.QueryTargets(ctx, kt2.AuthTokenId, query)