From a1cfe0277ae491708370faade4055ef962e9ccda Mon Sep 17 00:00:00 2001 From: Matt Brock Date: Wed, 8 Jan 2025 22:30:19 -0600 Subject: [PATCH] PR feedback --- lib/msgraph/models.go | 9 ----- lib/msgraph/paginated.go | 34 ++++--------------- .../fetchers/azure-sync/memberships.go | 11 +++--- .../fetchers/azure-sync/principals.go | 1 + .../fetchers/azure-sync/roledefinitions.go | 6 ++-- .../fetchers/azure-sync/virtualmachines.go | 2 +- 6 files changed, 17 insertions(+), 46 deletions(-) diff --git a/lib/msgraph/models.go b/lib/msgraph/models.go index fc2d9a5fee1e1..f76d3deaab858 100644 --- a/lib/msgraph/models.go +++ b/lib/msgraph/models.go @@ -184,12 +184,3 @@ func decodeGroupMember(msg json.RawMessage) (GroupMember, error) { return member, trace.Wrap(err) } - -func decodeDirectoryObject(msg json.RawMessage) (*DirectoryObject, error) { - var d *DirectoryObject - err := json.Unmarshal(msg, &d) - if err != nil { - return nil, trace.Wrap(err) - } - return d, nil -} diff --git a/lib/msgraph/paginated.go b/lib/msgraph/paginated.go index 8140dfd93c1f0..da44a4f442a1b 100644 --- a/lib/msgraph/paginated.go +++ b/lib/msgraph/paginated.go @@ -54,12 +54,7 @@ func iterateSimple[T any](c *Client, ctx context.Context, endpoint string, f fun func (c *Client) iterate(ctx context.Context, endpoint string, f func(json.RawMessage) bool) error { uri := *c.baseURL uri.Path = path.Join(uri.Path, endpoint) - rawQuery := url.Values{ - "$top": { - strconv.Itoa(c.pageSize), - }, - } - uri.RawQuery = rawQuery.Encode() + uri.RawQuery = url.Values{"$top": {strconv.Itoa(c.pageSize)}}.Encode() uriString := uri.String() for uriString != "" { resp, err := c.request(ctx, http.MethodGet, uriString, nil) @@ -114,27 +109,12 @@ func (c *Client) IterateServicePrincipals(ctx context.Context, f func(principal return iterateSimple(c, ctx, "servicePrincipals", f) } -func (c *Client) IterateUserMembership(ctx context.Context, userID string, f func(obj *DirectoryObject) bool) error { - var err error - itErr := c.iterate(ctx, path.Join("users", userID, "memberOf"), func(msg json.RawMessage) bool { - var page []json.RawMessage - if err = json.Unmarshal(msg, &page); err != nil { - return false - } - for _, entry := range page { - var d *DirectoryObject - err := json.Unmarshal(entry, &d) - if err != nil { - return false - } - f(d) - } - return true - }) - if err != nil { - return trace.Wrap(err) - } - return trace.Wrap(itErr) +// IterateUserMembership lists all group memberships for a given user ID as directory objects. +// `f` will be called for each directory object in the result set. +// if `f` returns `false`, the iteration is stopped (equivalent to `break` in a normal loop). +// Ref: [https://learn.microsoft.com/en-us/graph/api/group-list-memberof]. +func (c *Client) IterateUserMembership(ctx context.Context, userID string, f func(object *DirectoryObject) bool) error { + return iterateSimple(c, ctx, path.Join("users", userID, "memberOf"), f) } // IterateGroupMembers lists all members for the given Entra ID group using pagination. diff --git a/lib/srv/discovery/fetchers/azure-sync/memberships.go b/lib/srv/discovery/fetchers/azure-sync/memberships.go index ef022661e07fe..639bab5f4b0ad 100644 --- a/lib/srv/discovery/fetchers/azure-sync/memberships.go +++ b/lib/srv/discovery/fetchers/azure-sync/memberships.go @@ -2,18 +2,17 @@ package azuresync import ( "context" - accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha" - "github.com/gravitational/teleport/lib/msgraph" + "github.com/gravitational/trace" "golang.org/x/sync/errgroup" + + accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha" + "github.com/gravitational/teleport/lib/msgraph" ) const parallelism = 10 -func expandMemberships( - ctx context.Context, - cli *msgraph.Client, principals []*accessgraphv1alpha.AzurePrincipal, -) ([]*accessgraphv1alpha.AzurePrincipal, error) { +func expandMemberships(ctx context.Context, cli *msgraph.Client, principals []*accessgraphv1alpha.AzurePrincipal) ([]*accessgraphv1alpha.AzurePrincipal, error) { //nolint:unused // invoked in a dependent PR var eg errgroup.Group eg.SetLimit(parallelism) for _, principal := range principals { diff --git a/lib/srv/discovery/fetchers/azure-sync/principals.go b/lib/srv/discovery/fetchers/azure-sync/principals.go index 217ce35ddfc99..27ad1c472c891 100644 --- a/lib/srv/discovery/fetchers/azure-sync/principals.go +++ b/lib/srv/discovery/fetchers/azure-sync/principals.go @@ -20,6 +20,7 @@ package azuresync import ( "context" + "github.com/gravitational/trace" "google.golang.org/protobuf/types/known/timestamppb" diff --git a/lib/srv/discovery/fetchers/azure-sync/roledefinitions.go b/lib/srv/discovery/fetchers/azure-sync/roledefinitions.go index 529d0f2070c9c..485117f898b81 100644 --- a/lib/srv/discovery/fetchers/azure-sync/roledefinitions.go +++ b/lib/srv/discovery/fetchers/azure-sync/roledefinitions.go @@ -20,14 +20,14 @@ package azuresync import ( "context" - "fmt" //nolint:golint // used in a dependent PR - "github.com/gravitational/teleport/lib/utils/slices" + "fmt" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v2" "github.com/gravitational/trace" "google.golang.org/protobuf/types/known/timestamppb" accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha" + "github.com/gravitational/teleport/lib/utils/slices" ) // RoleDefinitionsClient specifies the methods used to fetch roles from Azure @@ -55,7 +55,7 @@ func fetchRoleDefinitions(ctx context.Context, subscriptionID string, cli RoleDe } pbPerms := make([]*accessgraphv1alpha.AzureRBACPermission, 0, len(roleDef.Properties.Permissions)) for _, perm := range roleDef.Properties.Permissions { - if perm.Actions == nil || perm.NotActions == nil { + if perm.Actions == nil && perm.NotActions == nil { fetchErrs = append(fetchErrs, trace.BadParameter("nil values on Permission object: %v", perm)) continue } diff --git a/lib/srv/discovery/fetchers/azure-sync/virtualmachines.go b/lib/srv/discovery/fetchers/azure-sync/virtualmachines.go index 12658e6a9b009..25876037a54f0 100644 --- a/lib/srv/discovery/fetchers/azure-sync/virtualmachines.go +++ b/lib/srv/discovery/fetchers/azure-sync/virtualmachines.go @@ -28,7 +28,7 @@ import ( accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha" ) -const allResourceGroups = "*" //nolint:unused // used in a dependent PR +const allResourceGroups = "*" // VirtualMachinesClient specifies the methods used to fetch virtual machines from Azure type VirtualMachinesClient interface {