Skip to content

Commit

Permalink
Collect resource info in repository_<resource>s.go
Browse files Browse the repository at this point in the history
  • Loading branch information
talanknight committed Sep 25, 2023
1 parent 65212b6 commit e77851e
Show file tree
Hide file tree
Showing 7 changed files with 235 additions and 203 deletions.
82 changes: 43 additions & 39 deletions internal/cmd/commands/daemon/search_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,24 @@ func newSearchTargetsHandlerFunc(ctx context.Context, repo *cache.Repository) (h
case util.IsNil(repo):
return nil, errors.New(ctx, errors.InvalidParameter, op, "repository is missing")
}

searchableResources := map[string]searcher{
"targets": &searchFns[*targets.Target]{
list: repo.ListTargets,
query: repo.QueryTargets,
searchResult: func(t []*targets.Target) *SearchResult {
return &SearchResult{Targets: t}
},
},
"sessions": &searchFns[*sessions.Session]{
list: repo.ListSessions,
query: repo.QuerySessions,
searchResult: func(s []*sessions.Session) *SearchResult {
return &SearchResult{Sessions: s}
},
},
}

return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
filter, err := handlers.NewFilter(ctx, r.URL.Query().Get(filterKey))
Expand Down Expand Up @@ -66,17 +84,12 @@ func newSearchTargetsHandlerFunc(ctx context.Context, repo *cache.Repository) (h

query := r.URL.Query().Get(queryKey)

var res *SearchResult
switch resource {
case "targets":
res, err = searchTargets(r.Context(), repo, authTokenId, query, filter)
case "sessions":
res, err = searchSessions(r.Context(), repo, authTokenId, query, filter)
default:
rSearcher, ok := searchableResources[resource]
if !ok {
writeError(w, fmt.Sprintf("search doesn't support %q resource", resource), http.StatusBadRequest)
return
}

res, err := rSearcher.search(r.Context(), authTokenId, query, filter)
if err != nil {
switch {
case errors.Match(errors.T(errors.InvalidParameter), err):
Expand All @@ -99,50 +112,41 @@ func newSearchTargetsHandlerFunc(ctx context.Context, repo *cache.Repository) (h
}, nil
}

func searchTargets(ctx context.Context, repo *cache.Repository, authTokenId, query string, filter *handlers.Filter) (*SearchResult, error) {
var found []*targets.Target
var err error
switch query {
case "":
found, err = repo.ListTargets(ctx, authTokenId)
default:
found, err = repo.QueryTargets(ctx, authTokenId, query)
}
if err != nil {
return nil, err
}
type searcher interface {
search(ctx context.Context, authTokenId, query string, filter *handlers.Filter) (*SearchResult, error)
}

finalTars := make([]*targets.Target, 0, len(found))
for _, item := range found {
if filter.Match(item) {
finalTars = append(finalTars, item)
}
}
return &SearchResult{
Targets: finalTars,
}, nil
// searchFns is a struct that collects all the functions needed to perform a search
// on a specific resource type.
type searchFns[T any] struct {
// list takes a context and an auth token and returns all resources for the
// user of that auth token.
list func(context.Context, string) ([]T, error)
// query takes a context, an auth token, and a query string and returns all
// resources for that auth token that matches the provided query parameter
query func(context.Context, string, string) ([]T, error)
searchResult func([]T) *SearchResult
}

func searchSessions(ctx context.Context, repo *cache.Repository, authTokenId, query string, filter *handlers.Filter) (*SearchResult, error) {
var found []*sessions.Session
func (l *searchFns[T]) search(ctx context.Context, authTokenId, query string, filter *handlers.Filter) (*SearchResult, error) {
const op = "daemon.(lookupFns).search"
var found []T
var err error
switch query {
case "":
found, err = repo.ListSessions(ctx, authTokenId)
found, err = l.list(ctx, authTokenId)
default:
found, err = repo.QuerySessions(ctx, authTokenId, query)
found, err = l.query(ctx, authTokenId, query)
}
if err != nil {
return nil, err
return nil, errors.Wrap(ctx, err, op)
}

finalSess := make([]*sessions.Session, 0, len(found))
finalResults := make([]T, 0, len(found))
for _, item := range found {
if filter.Match(item) {
finalSess = append(finalSess, item)
finalResults = append(finalResults, item)
}
}
return &SearchResult{
Sessions: finalSess,
}, nil
return l.searchResult(finalResults), nil
}
95 changes: 5 additions & 90 deletions internal/daemon/cache/repository_refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,58 +6,11 @@ package cache
import (
"context"
stderrors "errors"
"fmt"

"github.com/hashicorp/boundary/api"
"github.com/hashicorp/boundary/api/authtokens"
"github.com/hashicorp/boundary/api/sessions"
"github.com/hashicorp/boundary/api/targets"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/observability/event"
)

// TargetRetrievalFunc is a function that retrieves targets
// from the provided boundary addr using the provided token.
type TargetRetrievalFunc func(ctx context.Context, addr, token string) ([]*targets.Target, error)

func defaultTargetFunc(ctx context.Context, addr, token string) ([]*targets.Target, error) {
const op = "cache.defaultTargetFunc"
client, err := api.NewClient(&api.Config{
Addr: addr,
Token: token,
})
if err != nil {
return nil, errors.Wrap(ctx, err, op)
}
tarClient := targets.NewClient(client)
l, err := tarClient.List(ctx, "global", targets.WithRecursive(true))
if err != nil {
return nil, errors.Wrap(ctx, err, op)
}
return l.Items, nil
}

// SessionRetrievalFunc is a function that retrieves sessions
// from the provided boundary addr using the provided token.
type SessionRetrievalFunc func(ctx context.Context, addr, token string) ([]*sessions.Session, error)

func defaultSessionFunc(ctx context.Context, addr, token string) ([]*sessions.Session, error) {
const op = "cache.defaultSessionFunc"
client, err := api.NewClient(&api.Config{
Addr: addr,
Token: token,
})
if err != nil {
return nil, errors.Wrap(ctx, err, op)
}
sClient := sessions.NewClient(client)
l, err := sClient.List(ctx, "global", sessions.WithRecursive(true))
if err != nil {
return nil, errors.Wrap(ctx, err, op)
}
return l.Items, nil
}

// cleanAndPickAuthTokens removes from the cache all auth tokens which are
// evicted from the cache or no longer stored in a keyring and returns the
// remaining ones.
Expand Down Expand Up @@ -106,22 +59,10 @@ func (r *Repository) Refresh(ctx context.Context, opt ...Option) error {
return errors.Wrap(ctx, err, op)
}

opts, err := getOpts(opt...)
if err != nil {
return errors.Wrap(ctx, err, op)
}

us, err := r.listUsers(ctx)
if err != nil {
return errors.Wrap(ctx, err, op)
}
if opts.withTargetRetrievalFunc == nil {
opts.withTargetRetrievalFunc = defaultTargetFunc
}
if opts.withSessionRetrievalFunc == nil {
opts.withSessionRetrievalFunc = defaultSessionFunc
}

var retErr error
for _, u := range us {
tokens, err := r.cleanAndPickAuthTokens(ctx, u)
Expand All @@ -130,39 +71,13 @@ func (r *Repository) Refresh(ctx context.Context, opt ...Option) error {
continue
}

// Find and use a token for retrieving targets
for at, t := range tokens {
resp, err := opts.withTargetRetrievalFunc(ctx, u.Address, t)
if err != nil {
// TODO: If we get an error about the token no longer having
// permissions, remove it.
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for token %q", at.Id)))
continue
}

event.WriteSysEvent(ctx, op, fmt.Sprintf("updating %d targets for user %v", len(resp), u))
if err := r.refreshTargets(ctx, u, resp); err != nil {
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for user %v", u)))
}
break
if err := r.refreshTargets(ctx, u, tokens, opt...); err != nil {
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op))
}

// Find and use a token for retrieving sessions
for at, t := range tokens {
resp, err := opts.withSessionRetrievalFunc(ctx, u.Address, t)
if err != nil {
// TODO: If we get an error about the token no longer having
// permissions, remove it.
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for token %q", at.Id)))
continue
}

event.WriteSysEvent(ctx, op, fmt.Sprintf("updating %d sessions for user %v", len(resp), u))
if err := r.refreshSessions(ctx, u, resp); err != nil {
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for user %v", u)))
}
break
if err := r.refreshSessions(ctx, u, tokens, opt...); err != nil {
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op))
}

}
return retErr
}
50 changes: 8 additions & 42 deletions internal/daemon/cache/repository_refresh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,9 @@ import (
"github.com/hashicorp/boundary/api/authtokens"
"github.com/hashicorp/boundary/api/sessions"
"github.com/hashicorp/boundary/api/targets"
"github.com/hashicorp/boundary/internal/daemon/controller"
"github.com/hashicorp/boundary/internal/daemon/worker"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"

_ "github.com/hashicorp/boundary/internal/daemon/controller/handlers/targets/tcp"
)

// noopRetrievalFn is a function that satisfies the Refresh's With*RetrievalFn
Expand All @@ -29,6 +25,14 @@ func noopRetrievalFn[T any](context.Context, string, string) ([]T, error) {
return nil, nil
}

// staticRetrievalFn returns a function that satisfies the With*RetrievalFn
// and returns the provided slice and a nil error always
func staticRetrievalFn[T any](ret []T) func(context.Context, string, string) ([]T, error) {
return func(ctx context.Context, s1, s2 string) ([]T, error) {
return ret, nil
}
}

func TestCleanAndPickTokens(t *testing.T) {
ctx := context.Background()
s, err := Open(ctx)
Expand Down Expand Up @@ -283,44 +287,6 @@ func TestRefresh(t *testing.T) {
})
}

func TestDefaultTargetRetrievalFunc(t *testing.T) {
tc := controller.NewTestController(t, nil)
tc.Client().SetToken(tc.Token().Token)
tarClient := targets.NewClient(tc.Client())

tar1, err := tarClient.Create(tc.Context(), "tcp", "p_1234567890", targets.WithName("tar1"), targets.WithTcpTargetDefaultPort(1))
require.NoError(t, err)
require.NotNil(t, tar1)
tar2, err := tarClient.Create(tc.Context(), "tcp", "p_1234567890", targets.WithName("tar2"), targets.WithTcpTargetDefaultPort(2))
require.NoError(t, err)
require.NotNil(t, tar2)

got, err := defaultTargetFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token)
assert.NoError(t, err)
assert.Contains(t, got, tar1.Item)
assert.Contains(t, got, tar2.Item)
}

func TestDefaultSessionRetrievalFunc(t *testing.T) {
tc := controller.NewTestController(t, nil)
tc.Client().SetToken(tc.Token().Token)
tarClient := targets.NewClient(tc.Client())
_ = worker.NewTestWorker(t, &worker.TestWorkerOpts{
InitialUpstreams: tc.ClusterAddrs(),
WorkerAuthKms: tc.Config().WorkerAuthKms,
})

tar1, err := tarClient.Create(tc.Context(), "tcp", "p_1234567890", targets.WithName("tar1"), targets.WithTcpTargetDefaultPort(1), targets.WithAddress("address"))
require.NoError(t, err)
require.NotNil(t, tar1)
_, err = tarClient.AuthorizeSession(tc.Context(), tar1.Item.Id)
assert.NoError(t, err)

got, err := defaultSessionFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token)
assert.NoError(t, err)
assert.Len(t, got, 1)
}

func target(suffix string) *targets.Target {
return &targets.Target{
Id: fmt.Sprintf("target_%s", suffix),
Expand Down
Loading

0 comments on commit e77851e

Please sign in to comment.