From 0a4c1f4d675d57db5422207aabe513202db618f9 Mon Sep 17 00:00:00 2001 From: Matt Brock Date: Thu, 19 Dec 2024 17:47:38 -0600 Subject: [PATCH] Updating to use existing msgraph client --- lib/msgraph/models.go | 10 +- lib/msgraph/paginated.go | 17 +- .../fetchers/azure-sync/msggraphclient.go | 240 ------------------ .../fetchers/azure-sync/principals.go | 92 ++++--- 4 files changed, 82 insertions(+), 277 deletions(-) delete mode 100644 lib/srv/discovery/fetchers/azure-sync/msggraphclient.go diff --git a/lib/msgraph/models.go b/lib/msgraph/models.go index 52c3e97cfec7b..4f2181f81055d 100644 --- a/lib/msgraph/models.go +++ b/lib/msgraph/models.go @@ -28,9 +28,15 @@ type GroupMember interface { isGroupMember() } +type Membership struct { + Type string `json:"@odata.type"` + ID string `json:"id"` +} + type DirectoryObject struct { - ID *string `json:"id,omitempty"` - DisplayName *string `json:"displayName,omitempty"` + ID *string `json:"id,omitempty"` + DisplayName *string `json:"displayName,omitempty"` + MemberOf []Membership `json:"memberOf,omitempty"` } type Group struct { diff --git a/lib/msgraph/paginated.go b/lib/msgraph/paginated.go index 51c587f19d074..cc25162ef849f 100644 --- a/lib/msgraph/paginated.go +++ b/lib/msgraph/paginated.go @@ -54,7 +54,14 @@ func iterateSimple[T any](c *Client, ctx context.Context, endpoint string, f fun func (c *Client) iterate(ctx context.Context, endpoint string, f func(json.RawMessage) bool) error { uri := *c.baseURL uri.Path = path.Join(uri.Path, endpoint) - uri.RawQuery = url.Values{"$top": {strconv.Itoa(c.pageSize)}}.Encode() + uri.RawQuery = url.Values{ + "$top": { + strconv.Itoa(c.pageSize), + }, + "$expand": { + "memberOf", + }, + }.Encode() uriString := uri.String() for uriString != "" { resp, err := c.request(ctx, http.MethodGet, uriString, nil) @@ -101,6 +108,14 @@ func (c *Client) IterateUsers(ctx context.Context, f func(*User) bool) error { return iterateSimple(c, ctx, "users", f) } +// IterateServicePrincipals lists all service principals in the Entra ID directory using pagination. +// `f` will be called for each object in the result set. +// if `f` returns `false`, the iteration is stopped (equivalent to `break` in a normal loop). +// Ref: [https://learn.microsoft.com/en-us/graph/api/user-list]. +func (c *Client) IterateServicePrincipals(ctx context.Context, f func(principal *ServicePrincipal) bool) error { + return iterateSimple(c, ctx, "servicePrincipals", f) +} + // IterateGroupMembers lists all members for the given Entra ID group using pagination. // `f` will be called for each object in the result set. // if `f` returns `false`, the iteration is stopped (equivalent to `break` in a normal loop). diff --git a/lib/srv/discovery/fetchers/azure-sync/msggraphclient.go b/lib/srv/discovery/fetchers/azure-sync/msggraphclient.go deleted file mode 100644 index 75d2960d7fa55..0000000000000 --- a/lib/srv/discovery/fetchers/azure-sync/msggraphclient.go +++ /dev/null @@ -1,240 +0,0 @@ -/* - * Teleport - * Copyright (C) 2024 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package azure_sync - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - "strings" - "time" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore" -) - -// GraphClient represents generic MS API client -type GraphClient struct { - token azcore.AccessToken -} - -const ( - usersSuffix = "users" - groupsSuffix = "groups" - servicePrincipalsSuffix = "servicePrincipals" - graphBaseURL = "https://graph.microsoft.com/v1.0" - httpTimeout = time.Second * 30 -) - -// graphError represents MS Graph error -type graphError struct { - E struct { - Code string `json:"code"` - Message string `json:"message"` - } `json:"error"` -} - -// genericGraphResponse represents the utility struct for parsing MS Graph API response -type genericGraphResponse struct { - Context string `json:"@odata.context"` - Count int `json:"@odata.count"` - NextLink string `json:"@odata.nextLink"` - Value json.RawMessage `json:"value"` -} - -// User represents user resource -type User struct { - ID string `json:"id"` - Name string `json:"displayName"` - MemberOf []Membership `json:"memberOf"` -} - -type Membership struct { - Type string `json:"@odata.type"` - ID string `json:"id"` -} - -// request represents generic request structure -type request struct { - // Method HTTP method - Method string - // URL which overrides URL construction - URL *string - // Path to a resource - Path string - // Expand $expand value - Expand []string - // Filter $filter value - Filter string - // Body request body - Body string - // Response represents template structure for a response - Response interface{} - // Err represents template structure for an error - Err error - // SuccessCode http code representing success - SuccessCode int -} - -// GetURL builds the request URL -func (r *request) GetURL() (string, error) { - if r.URL != nil { - return *r.URL, nil - } - u, err := url.Parse(graphBaseURL) - if err != nil { - return "", err - } - - data := url.Values{} - if len(r.Expand) > 0 { - data.Set("$expand", strings.Join(r.Expand, ",")) - } - if r.Filter != "" { - data.Set("$filter", r.Filter) - } - - u.Path = u.Path + "/" + r.Path - u.RawQuery = data.Encode() - - return u.String(), nil -} - -// NewGraphClient creates MS Graph API client -func NewGraphClient(token azcore.AccessToken) *GraphClient { - return &GraphClient{ - token: token, - } -} - -// Error returns error string -func (e graphError) Error() string { - return e.E.Code + " " + e.E.Message -} - -func (c *GraphClient) ListUsers(ctx context.Context) ([]User, error) { - return c.listIdentities(ctx, usersSuffix, []string{"memberOf"}) -} - -func (c *GraphClient) ListGroups(ctx context.Context) ([]User, error) { - return c.listIdentities(ctx, groupsSuffix, []string{"memberOf"}) -} - -func (c *GraphClient) ListServicePrincipals(ctx context.Context) ([]User, error) { - return c.listIdentities(ctx, servicePrincipalsSuffix, []string{"memberOf"}) -} - -func (c *GraphClient) listIdentities(ctx context.Context, idType string, expand []string) ([]User, error) { - var users []User - var nextLink *string - for { - g := &genericGraphResponse{} - req := request{ - Method: http.MethodGet, - Path: idType, - Expand: expand, - Response: &g, - Err: &graphError{}, - URL: nextLink, - } - err := c.request(ctx, req) - if err != nil { - return nil, err - } - var newUsers []User - err = json.NewDecoder(bytes.NewReader(g.Value)).Decode(&newUsers) - if err != nil { - return nil, err - } - users = append(users, newUsers...) - if g.NextLink == "" { - break - } - nextLink = &g.NextLink - } - - return users, nil -} - -// request sends the request to the graph/bot service and returns response body as bytes slice -func (c *GraphClient) request(ctx context.Context, req request) error { - reqUrl, err := req.GetURL() - if err != nil { - return err - } - - r, err := http.NewRequestWithContext(ctx, req.Method, reqUrl, strings.NewReader(req.Body)) - if err != nil { - return err - } - - r.Header.Set("Authorization", "Bearer "+c.token.Token) - r.Header.Set("Content-Type", "application/json") - - client := http.Client{Timeout: httpTimeout} - resp, err := client.Do(r) - if err != nil { - return err - } - - defer func(r *http.Response) { - _ = r.Body.Close() - }(resp) - - b, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - - expectedCode := req.SuccessCode - if expectedCode == 0 { - expectedCode = http.StatusOK - } - - if expectedCode == resp.StatusCode { - if req.Response == nil { - return nil - } - - err := json.NewDecoder(bytes.NewReader(b)).Decode(req.Response) - if err != nil { - return err - } - } else { - if req.Err == nil { - return fmt.Errorf("Error requesting MS Graph API: %v", string(b)) - } - - err := json.NewDecoder(bytes.NewReader(b)).Decode(req.Err) - if err != nil { - return err - } - - if req.Err.Error() == "" { - return fmt.Errorf("Error requesting MS Graph API. Expected response code was %v, but is %v", expectedCode, resp.StatusCode) - } - - return req.Err - } - - return nil -} diff --git a/lib/srv/discovery/fetchers/azure-sync/principals.go b/lib/srv/discovery/fetchers/azure-sync/principals.go index f20878e7e3a61..757c78255ed46 100644 --- a/lib/srv/discovery/fetchers/azure-sync/principals.go +++ b/lib/srv/discovery/fetchers/azure-sync/principals.go @@ -20,63 +20,87 @@ package azure_sync import ( "context" - "slices" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore" //nolint:unused // used in a dependent PR - "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/gravitational/teleport/lib/msgraph" "github.com/gravitational/trace" "google.golang.org/protobuf/types/known/timestamppb" accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha" ) -const groupType = "#microsoft.graph.group" //nolint:unused // used in a dependent PR -const defaultGraphScope = "https://graph.microsoft.com/.default" //nolint:unused // used in a dependent PR - // fetchPrincipals fetches the Azure principals (users, groups, and service principals) using the Graph API -func fetchPrincipals(ctx context.Context, subscriptionID string, cred azcore.TokenCredential) ([]*accessgraphv1alpha.AzurePrincipal, error) { //nolint:unused // used in a dependent PR - // Get the graph client - scopes := []string{defaultGraphScope} - token, err := cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: scopes}) +func fetchPrincipals(ctx context.Context, subscriptionID string, cli *msgraph.Client) ([]*accessgraphv1alpha.AzurePrincipal, error) { + // Fetch the users, groups, and service principals + var users []*msgraph.User + err := cli.IterateUsers(ctx, func(user *msgraph.User) bool { + users = append(users, user) + return true + }) if err != nil { return nil, trace.Wrap(err) } - cli := NewGraphClient(token) - // Fetch the users, groups, and managed identities - users, err := cli.ListUsers(ctx) - if err != nil { - return nil, trace.Wrap(err) - } - groups, err := cli.ListGroups(ctx) + var groups []*msgraph.Group + err = cli.IterateGroups(ctx, func(group *msgraph.Group) bool { + groups = append(groups, group) + return true + }) if err != nil { return nil, trace.Wrap(err) } - svcPrincipals, err := cli.ListServicePrincipals(ctx) + + var servicePrincipals []*msgraph.ServicePrincipal + err = cli.IterateServicePrincipals(ctx, func(servicePrincipal *msgraph.ServicePrincipal) bool { + servicePrincipals = append(servicePrincipals, servicePrincipal) + return true + }) if err != nil { return nil, trace.Wrap(err) } - principals := slices.Concat(users, groups, svcPrincipals) - // Return the users as protobuf messages - pbPrincipals := make([]*accessgraphv1alpha.AzurePrincipal, 0, len(principals)) - for _, principal := range principals { - // Extract group membership - memberOf := make([]string, 0) - for _, member := range principal.MemberOf { - if member.Type == groupType { - memberOf = append(memberOf, member.ID) - } + // Return the users, groups, and service principals as protobuf messages + var pbPrincipals []*accessgraphv1alpha.AzurePrincipal + for _, user := range users { + var memberOf []string + for _, member := range user.MemberOf { + memberOf = append(memberOf, member.ID) + } + pbPrincipals = append(pbPrincipals, &accessgraphv1alpha.AzurePrincipal{ + Id: *user.ID, + SubscriptionId: subscriptionID, + LastSyncTime: timestamppb.Now(), + DisplayName: *user.DisplayName, + MemberOf: memberOf, + ObjectType: "user", + }) + } + for _, group := range groups { + var memberOf []string + for _, member := range group.MemberOf { + memberOf = append(memberOf, member.ID) } - // Create the protobuf principal and append it to the list - pbPrincipal := &accessgraphv1alpha.AzurePrincipal{ - Id: principal.ID, + pbPrincipals = append(pbPrincipals, &accessgraphv1alpha.AzurePrincipal{ + Id: *group.ID, SubscriptionId: subscriptionID, LastSyncTime: timestamppb.Now(), - DisplayName: principal.Name, + DisplayName: *group.DisplayName, MemberOf: memberOf, + ObjectType: "group", + }) + } + for _, sp := range servicePrincipals { + var memberOf []string + for _, member := range sp.MemberOf { + memberOf = append(memberOf, member.ID) } - pbPrincipals = append(pbPrincipals, pbPrincipal) + pbPrincipals = append(pbPrincipals, &accessgraphv1alpha.AzurePrincipal{ + Id: *sp.ID, + SubscriptionId: subscriptionID, + LastSyncTime: timestamppb.Now(), + DisplayName: *sp.DisplayName, + MemberOf: memberOf, + ObjectType: "servicePrincipal", + }) } + return pbPrincipals, nil }