Skip to content

Commit

Permalink
Collect resource info in repository_<resource>s.go
Browse files Browse the repository at this point in the history
  • Loading branch information
talanknight committed Sep 25, 2023
1 parent 65212b6 commit 7ff2de5
Show file tree
Hide file tree
Showing 7 changed files with 191 additions and 161 deletions.
82 changes: 43 additions & 39 deletions internal/cmd/commands/daemon/search_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,24 @@ func newSearchTargetsHandlerFunc(ctx context.Context, repo *cache.Repository) (h
case util.IsNil(repo):
return nil, errors.New(ctx, errors.InvalidParameter, op, "repository is missing")
}

searchableResources := map[string]searcher{
"targets": &searchFns[*targets.Target]{
list: repo.ListTargets,
query: repo.QueryTargets,
searchResult: func(t []*targets.Target) *SearchResult {
return &SearchResult{Targets: t}
},
},
"sessions": &searchFns[*sessions.Session]{
list: repo.ListSessions,
query: repo.QuerySessions,
searchResult: func(s []*sessions.Session) *SearchResult {
return &SearchResult{Sessions: s}
},
},
}

return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
filter, err := handlers.NewFilter(ctx, r.URL.Query().Get(filterKey))
Expand Down Expand Up @@ -66,17 +84,12 @@ func newSearchTargetsHandlerFunc(ctx context.Context, repo *cache.Repository) (h

query := r.URL.Query().Get(queryKey)

var res *SearchResult
switch resource {
case "targets":
res, err = searchTargets(r.Context(), repo, authTokenId, query, filter)
case "sessions":
res, err = searchSessions(r.Context(), repo, authTokenId, query, filter)
default:
rSearcher, ok := searchableResources[resource]
if !ok {
writeError(w, fmt.Sprintf("search doesn't support %q resource", resource), http.StatusBadRequest)
return
}

res, err := rSearcher.search(r.Context(), authTokenId, query, filter)
if err != nil {
switch {
case errors.Match(errors.T(errors.InvalidParameter), err):
Expand All @@ -99,50 +112,41 @@ func newSearchTargetsHandlerFunc(ctx context.Context, repo *cache.Repository) (h
}, nil
}

func searchTargets(ctx context.Context, repo *cache.Repository, authTokenId, query string, filter *handlers.Filter) (*SearchResult, error) {
var found []*targets.Target
var err error
switch query {
case "":
found, err = repo.ListTargets(ctx, authTokenId)
default:
found, err = repo.QueryTargets(ctx, authTokenId, query)
}
if err != nil {
return nil, err
}
type searcher interface {
search(ctx context.Context, authTokenId, query string, filter *handlers.Filter) (*SearchResult, error)
}

finalTars := make([]*targets.Target, 0, len(found))
for _, item := range found {
if filter.Match(item) {
finalTars = append(finalTars, item)
}
}
return &SearchResult{
Targets: finalTars,
}, nil
// searchFns is a struct that collects all the functions needed to perform a search
// on a specific resource type.
type searchFns[T any] struct {
// list takes a context and an auth token and returns all resources for the
// user of that auth token.
list func(context.Context, string) ([]T, error)
// query takes a context, an auth token, and a query string and returns all
// resources for that auth token that matches the provided query parameter
query func(context.Context, string, string) ([]T, error)
searchResult func([]T) *SearchResult
}

func searchSessions(ctx context.Context, repo *cache.Repository, authTokenId, query string, filter *handlers.Filter) (*SearchResult, error) {
var found []*sessions.Session
func (l *searchFns[T]) search(ctx context.Context, authTokenId, query string, filter *handlers.Filter) (*SearchResult, error) {
const op = "daemon.(lookupFns).search"
var found []T
var err error
switch query {
case "":
found, err = repo.ListSessions(ctx, authTokenId)
found, err = l.list(ctx, authTokenId)
default:
found, err = repo.QuerySessions(ctx, authTokenId, query)
found, err = l.query(ctx, authTokenId, query)
}
if err != nil {
return nil, err
return nil, errors.Wrap(ctx, err, op)
}

finalSess := make([]*sessions.Session, 0, len(found))
finalResults := make([]T, 0, len(found))
for _, item := range found {
if filter.Match(item) {
finalSess = append(finalSess, item)
finalResults = append(finalResults, item)
}
}
return &SearchResult{
Sessions: finalSess,
}, nil
return l.searchResult(finalResults), nil
}
95 changes: 5 additions & 90 deletions internal/daemon/cache/repository_refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,58 +6,11 @@ package cache
import (
"context"
stderrors "errors"
"fmt"

"github.com/hashicorp/boundary/api"
"github.com/hashicorp/boundary/api/authtokens"
"github.com/hashicorp/boundary/api/sessions"
"github.com/hashicorp/boundary/api/targets"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/observability/event"
)

// TargetRetrievalFunc is a function that retrieves targets
// from the provided boundary addr using the provided token.
type TargetRetrievalFunc func(ctx context.Context, addr, token string) ([]*targets.Target, error)

func defaultTargetFunc(ctx context.Context, addr, token string) ([]*targets.Target, error) {
const op = "cache.defaultTargetFunc"
client, err := api.NewClient(&api.Config{
Addr: addr,
Token: token,
})
if err != nil {
return nil, errors.Wrap(ctx, err, op)
}
tarClient := targets.NewClient(client)
l, err := tarClient.List(ctx, "global", targets.WithRecursive(true))
if err != nil {
return nil, errors.Wrap(ctx, err, op)
}
return l.Items, nil
}

// SessionRetrievalFunc is a function that retrieves sessions
// from the provided boundary addr using the provided token.
type SessionRetrievalFunc func(ctx context.Context, addr, token string) ([]*sessions.Session, error)

func defaultSessionFunc(ctx context.Context, addr, token string) ([]*sessions.Session, error) {
const op = "cache.defaultSessionFunc"
client, err := api.NewClient(&api.Config{
Addr: addr,
Token: token,
})
if err != nil {
return nil, errors.Wrap(ctx, err, op)
}
sClient := sessions.NewClient(client)
l, err := sClient.List(ctx, "global", sessions.WithRecursive(true))
if err != nil {
return nil, errors.Wrap(ctx, err, op)
}
return l.Items, nil
}

// cleanAndPickAuthTokens removes from the cache all auth tokens which are
// evicted from the cache or no longer stored in a keyring and returns the
// remaining ones.
Expand Down Expand Up @@ -106,22 +59,10 @@ func (r *Repository) Refresh(ctx context.Context, opt ...Option) error {
return errors.Wrap(ctx, err, op)
}

opts, err := getOpts(opt...)
if err != nil {
return errors.Wrap(ctx, err, op)
}

us, err := r.listUsers(ctx)
if err != nil {
return errors.Wrap(ctx, err, op)
}
if opts.withTargetRetrievalFunc == nil {
opts.withTargetRetrievalFunc = defaultTargetFunc
}
if opts.withSessionRetrievalFunc == nil {
opts.withSessionRetrievalFunc = defaultSessionFunc
}

var retErr error
for _, u := range us {
tokens, err := r.cleanAndPickAuthTokens(ctx, u)
Expand All @@ -130,39 +71,13 @@ func (r *Repository) Refresh(ctx context.Context, opt ...Option) error {
continue
}

// Find and use a token for retrieving targets
for at, t := range tokens {
resp, err := opts.withTargetRetrievalFunc(ctx, u.Address, t)
if err != nil {
// TODO: If we get an error about the token no longer having
// permissions, remove it.
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for token %q", at.Id)))
continue
}

event.WriteSysEvent(ctx, op, fmt.Sprintf("updating %d targets for user %v", len(resp), u))
if err := r.refreshTargets(ctx, u, resp); err != nil {
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for user %v", u)))
}
break
if err := r.refreshTargets(ctx, u, tokens, opt...); err != nil {
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op))
}

// Find and use a token for retrieving sessions
for at, t := range tokens {
resp, err := opts.withSessionRetrievalFunc(ctx, u.Address, t)
if err != nil {
// TODO: If we get an error about the token no longer having
// permissions, remove it.
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for token %q", at.Id)))
continue
}

event.WriteSysEvent(ctx, op, fmt.Sprintf("updating %d sessions for user %v", len(resp), u))
if err := r.refreshSessions(ctx, u, resp); err != nil {
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for user %v", u)))
}
break
if err := r.refreshSessions(ctx, u, tokens, opt...); err != nil {
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op))
}

}
return retErr
}
8 changes: 8 additions & 0 deletions internal/daemon/cache/repository_refresh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ func noopRetrievalFn[T any](context.Context, string, string) ([]T, error) {
return nil, nil
}

// staticRetrievalFn returns a function that satisfies the With*RetrievalFn
// and returns the provided slice and a nil error always
func staticRetrievalFn[T any](ret []T) func(context.Context, string, string) ([]T, error) {
return func(ctx context.Context, s1, s2 string) ([]T, error) {
return ret, nil
}
}

func TestCleanAndPickTokens(t *testing.T) {
ctx := context.Background()
s, err := Open(ctx)
Expand Down
76 changes: 63 additions & 13 deletions internal/daemon/cache/repository_sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,46 +7,99 @@ import (
"context"
"database/sql"
"encoding/json"
stdErrors "errors"
stderrors "errors"
"fmt"

"github.com/hashicorp/boundary/api"
"github.com/hashicorp/boundary/api/sessions"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/observability/event"
"github.com/hashicorp/boundary/internal/types/resource"
"github.com/hashicorp/boundary/internal/util"
"github.com/hashicorp/mql"
)

func (r *Repository) refreshSessions(ctx context.Context, u *user, sessions []*sessions.Session) error {
// SessionRetrievalFunc is a function that retrieves sessions
// from the provided boundary addr using the provided token.
type SessionRetrievalFunc func(ctx context.Context, addr, token string) ([]*sessions.Session, error)

func defaultSessionFunc(ctx context.Context, addr, token string) ([]*sessions.Session, error) {
const op = "cache.defaultSessionFunc"
client, err := api.NewClient(&api.Config{
Addr: addr,
Token: token,
})
if err != nil {
return nil, errors.Wrap(ctx, err, op)
}
sClient := sessions.NewClient(client)
l, err := sClient.List(ctx, "global", sessions.WithRecursive(true))
if err != nil {
return nil, errors.Wrap(ctx, err, op)
}
return l.Items, nil
}

func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[AuthToken]string, opt ...Option) error {
const op = "cache.(Repository).refreshSessions"
switch {
case util.IsNil(u):
return errors.New(ctx, errors.InvalidParameter, op, "user is nil")
case u.Id == "":
return errors.New(ctx, errors.InvalidParameter, op, "user id is missing")
case u.Address == "":
return errors.New(ctx, errors.InvalidParameter, op, "user boundary address is missing")
}

foundU := u.clone()
if err := r.rw.LookupById(ctx, foundU); err != nil {
// if this user isn't known, error out.
return errors.Wrap(ctx, err, op, errors.WithMsg("looking up user"))
opts, err := getOpts(opt...)
if err != nil {
return errors.Wrap(ctx, err, op)
}
if opts.withSessionRetrievalFunc == nil {
opts.withSessionRetrievalFunc = defaultSessionFunc
}

// Find and use a token for retrieving sessions
var gotResponse bool
var resp []*sessions.Session
var retErr error
for at, t := range tokens {
resp, err = opts.withSessionRetrievalFunc(ctx, u.Address, t)
if err != nil {
// TODO: If we get an error about the token no longer having
// permissions, remove it.
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for token %q", at.Id)))
continue
}
gotResponse = true
break
}

_, err := r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(r db.Reader, w db.Writer) error {
if retErr != nil {
if saveErr := r.SaveError(ctx, u, resource.Session.String(), retErr); saveErr != nil {
return stderrors.Join(err, errors.Wrap(ctx, saveErr, op))
}
}
if !gotResponse {
return retErr
}

event.WriteSysEvent(ctx, op, fmt.Sprintf("updating %d sessions for user %v", len(resp), u))
_, err = r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(r db.Reader, w db.Writer) error {
// TODO: Instead of deleting everything, use refresh tokens and apply the delta
if _, err := w.Exec(ctx, "delete from session where user_id = @user_id",
[]any{sql.Named("user_id", foundU.Id)}); err != nil {
[]any{sql.Named("user_id", u.Id)}); err != nil {
return err
}

for _, s := range sessions {
for _, s := range resp {
item, err := json.Marshal(s)
if err != nil {
return err
}
newSession := &Session{
UserId: foundU.Id,
UserId: u.Id,
Id: s.Id,
Type: s.Type,
Status: s.Status,
Expand All @@ -64,9 +117,6 @@ func (r *Repository) refreshSessions(ctx context.Context, u *user, sessions []*s
return nil
})
if err != nil {
if saveErr := r.SaveError(ctx, u, resource.Session.String(), err); saveErr != nil {
return stdErrors.Join(err, errors.Wrap(ctx, saveErr, op))
}
return errors.Wrap(ctx, err, op)
}
return nil
Expand Down
Loading

0 comments on commit 7ff2de5

Please sign in to comment.