diff --git a/lib/msgraph/paginated.go b/lib/msgraph/paginated.go index 89140b1879c5f..27189f25d3697 100644 --- a/lib/msgraph/paginated.go +++ b/lib/msgraph/paginated.go @@ -29,10 +29,17 @@ import ( "github.com/gravitational/trace" ) +const expandParameter = "$expand" +const expandMemberOf = "memberOf" + +type IterateOptions struct { + ExpandMembers bool +} + // iterateSimple implements pagination for "simple" object lists, where additional logic isn't needed -func iterateSimple[T any](c *Client, ctx context.Context, endpoint string, params *url.Values, f func(*T) bool) error { +func iterateSimple[T any](c *Client, ctx context.Context, endpoint string, opts *IterateOptions, f func(*T) bool) error { var err error - itErr := c.iterate(ctx, endpoint, params, func(msg json.RawMessage) bool { + itErr := c.iterate(ctx, endpoint, opts, func(msg json.RawMessage) bool { var page []T if err = json.Unmarshal(msg, &page); err != nil { return false @@ -51,7 +58,7 @@ func iterateSimple[T any](c *Client, ctx context.Context, endpoint string, param } // iterate implements pagination for "list" endpoints. -func (c *Client) iterate(ctx context.Context, endpoint string, params *url.Values, f func(json.RawMessage) bool) error { +func (c *Client) iterate(ctx context.Context, endpoint string, opts *IterateOptions, f func(json.RawMessage) bool) error { uri := *c.baseURL uri.Path = path.Join(uri.Path, endpoint) rawQuery := url.Values{ @@ -59,11 +66,9 @@ func (c *Client) iterate(ctx context.Context, endpoint string, params *url.Value strconv.Itoa(c.pageSize), }, } - if params != nil { - for key, values := range *params { - for _, value := range values { - rawQuery.Add(key, value) - } + if opts != nil { + if opts.ExpandMembers { + rawQuery.Set(expandParameter, expandMemberOf) } } uri.RawQuery = rawQuery.Encode() @@ -93,32 +98,32 @@ func (c *Client) iterate(ctx context.Context, endpoint string, params *url.Value // `f` will be called for each object in the result set. // if `f` returns `false`, the iteration is stopped (equivalent to `break` in a normal loop). // Ref: [https://learn.microsoft.com/en-us/graph/api/application-list]. -func (c *Client) IterateApplications(ctx context.Context, params *url.Values, f func(*Application) bool) error { - return iterateSimple(c, ctx, "applications", params, f) +func (c *Client) IterateApplications(ctx context.Context, opts *IterateOptions, f func(*Application) bool) error { + return iterateSimple(c, ctx, "applications", opts, f) } // IterateGroups lists all groups in the Entra ID directory using pagination. // `f` will be called for each object in the result set. // if `f` returns `false`, the iteration is stopped (equivalent to `break` in a normal loop). // Ref: [https://learn.microsoft.com/en-us/graph/api/group-list]. -func (c *Client) IterateGroups(ctx context.Context, params *url.Values, f func(*Group) bool) error { - return iterateSimple(c, ctx, "groups", params, f) +func (c *Client) IterateGroups(ctx context.Context, opts *IterateOptions, f func(*Group) bool) error { + return iterateSimple(c, ctx, "groups", opts, f) } // IterateUsers lists all users in the Entra ID directory using pagination. // `f` will be called for each object in the result set. // if `f` returns `false`, the iteration is stopped (equivalent to `break` in a normal loop). // Ref: [https://learn.microsoft.com/en-us/graph/api/user-list]. -func (c *Client) IterateUsers(ctx context.Context, params *url.Values, f func(*User) bool) error { - return iterateSimple(c, ctx, "users", params, f) +func (c *Client) IterateUsers(ctx context.Context, opts *IterateOptions, f func(*User) bool) error { + return iterateSimple(c, ctx, "users", opts, f) } // IterateServicePrincipals lists all service principals in the Entra ID directory using pagination. // `f` will be called for each object in the result set. // if `f` returns `false`, the iteration is stopped (equivalent to `break` in a normal loop). // Ref: [https://learn.microsoft.com/en-us/graph/api/user-list]. -func (c *Client) IterateServicePrincipals(ctx context.Context, params *url.Values, f func(principal *ServicePrincipal) bool) error { - return iterateSimple(c, ctx, "servicePrincipals", params, f) +func (c *Client) IterateServicePrincipals(ctx context.Context, opts *IterateOptions, f func(principal *ServicePrincipal) bool) error { + return iterateSimple(c, ctx, "servicePrincipals", opts, f) } // IterateGroupMembers lists all members for the given Entra ID group using pagination. diff --git a/lib/srv/discovery/fetchers/azure-sync/principals.go b/lib/srv/discovery/fetchers/azure-sync/principals.go index f372ee8de756e..7c8aa651c4e43 100644 --- a/lib/srv/discovery/fetchers/azure-sync/principals.go +++ b/lib/srv/discovery/fetchers/azure-sync/principals.go @@ -20,8 +20,6 @@ package azure_sync import ( "context" - "net/url" - "github.com/gravitational/trace" "google.golang.org/protobuf/types/known/timestamppb" @@ -40,13 +38,13 @@ type queryResult struct { // fetchPrincipals fetches the Azure principals (users, groups, and service principals) using the Graph API func fetchPrincipals(ctx context.Context, subscriptionID string, cli *msgraph.Client) ([]*accessgraphv1alpha.AzurePrincipal, error) { //nolint: unused // invoked in a dependent PR - var params = &url.Values{ - "$expand": []string{"memberOf"}, + var opts = &msgraph.IterateOptions{ + ExpandMembers: true, } // Fetch the users, groups, and service principals as directory objects var queryResults []queryResult - err := cli.IterateUsers(ctx, params, func(user *msgraph.User) bool { + err := cli.IterateUsers(ctx, opts, func(user *msgraph.User) bool { res := queryResult{metadata: dirObjMetadata{objectType: "user"}, dirObj: user.DirectoryObject} queryResults = append(queryResults, res) return true @@ -54,7 +52,7 @@ func fetchPrincipals(ctx context.Context, subscriptionID string, cli *msgraph.Cl if err != nil { return nil, trace.Wrap(err) } - err = cli.IterateGroups(ctx, params, func(group *msgraph.Group) bool { + err = cli.IterateGroups(ctx, opts, func(group *msgraph.Group) bool { res := queryResult{metadata: dirObjMetadata{objectType: "group"}, dirObj: group.DirectoryObject} queryResults = append(queryResults, res) return true @@ -62,7 +60,7 @@ func fetchPrincipals(ctx context.Context, subscriptionID string, cli *msgraph.Cl if err != nil { return nil, trace.Wrap(err) } - err = cli.IterateServicePrincipals(ctx, params, func(servicePrincipal *msgraph.ServicePrincipal) bool { + err = cli.IterateServicePrincipals(ctx, opts, func(servicePrincipal *msgraph.ServicePrincipal) bool { res := queryResult{metadata: dirObjMetadata{objectType: "servicePrincipal"}, dirObj: servicePrincipal.DirectoryObject} queryResults = append(queryResults, res) return true