From 7bdc2d75c4eda17e3de6bb7e9416c182e7b217ea Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Wed, 26 Feb 2025 16:36:35 -0700 Subject: [PATCH] azure: add support for service principals --- pkg/directory/azure/api.go | 4 ++ pkg/directory/azure/azure_test.go | 27 +++++++- pkg/directory/azure/delta.go | 105 +++++++++++++++++++++++++----- 3 files changed, 118 insertions(+), 18 deletions(-) diff --git a/pkg/directory/azure/api.go b/pkg/directory/azure/api.go index 92e86f6..a6f0785 100644 --- a/pkg/directory/azure/api.go +++ b/pkg/directory/azure/api.go @@ -7,6 +7,10 @@ type ( ID string `json:"id"` DisplayName string `json:"displayName"` } + apiServicePrincipal struct { + ID string `json:"id"` + DisplayName string `json:"displayName"` + } apiUser struct { ID string `json:"id"` DisplayName string `json:"displayName"` diff --git a/pkg/directory/azure/azure_test.go b/pkg/directory/azure/azure_test.go index 946f192..c11a6a0 100644 --- a/pkg/directory/azure/azure_test.go +++ b/pkg/directory/azure/azure_test.go @@ -17,7 +17,7 @@ import ( "github.com/pomerium/datasource/pkg/directory" ) -type M = map[string]interface{} +type M = map[string]any func newMockAPI(t *testing.T, _ *httptest.Server) http.Handler { t.Helper() @@ -59,6 +59,7 @@ func newMockAPI(t *testing.T, _ *httptest.Server) http.Handler { "displayName": "Admin Group", "members@delta": []M{ {"@odata.type": "#microsoft.graph.user", "id": "user-1"}, + {"@odata.type": "#microsoft.graph.servicePrincipal", "id": "service-principal-1"}, }, }, { @@ -73,6 +74,20 @@ func newMockAPI(t *testing.T, _ *httptest.Server) http.Handler { }, }) }) + r.Get("/servicePrincipals/delta", func(w http.ResponseWriter, _ *http.Request) { + _ = json.NewEncoder(w).Encode(M{ + "value": []M{ + { + "id": "service-principal-1", + "displayName": "Service Principal 1", + }, + { + "id": "service-principal-2", + "displayName": "Service Principal 2", + }, + }, + }) + }) r.Get("/users/delta", func(w http.ResponseWriter, _ *http.Request) { _ = json.NewEncoder(w).Encode(M{ "value": []M{ @@ -149,6 +164,16 @@ func TestProvider_GetDirectory(t *testing.T) { {ID: "test", Name: "Test Group"}, }, groups) assert.Equal(t, []directory.User{ + { + ID: "service-principal-1", + GroupIDs: []string{"admin"}, + DisplayName: "Service Principal 1", + }, + { + ID: "service-principal-2", + GroupIDs: []string{}, + DisplayName: "Service Principal 2", + }, { ID: "user-1", GroupIDs: []string{"admin"}, diff --git a/pkg/directory/azure/delta.go b/pkg/directory/azure/delta.go index 6b9f8e2..c031432 100644 --- a/pkg/directory/azure/delta.go +++ b/pkg/directory/azure/delta.go @@ -2,6 +2,7 @@ package azure import ( "context" + "errors" "net/url" "sort" @@ -9,17 +10,20 @@ import ( ) const ( - groupsDeltaPath = "/v1.0/groups/delta" - usersDeltaPath = "/v1.0/users/delta" + groupsDeltaPath = "/v1.0/groups/delta" + servicePrincipalsDeltaPath = "/v1.0/servicePrincipals/delta" + usersDeltaPath = "/v1.0/users/delta" ) type ( deltaCollection struct { - provider *Provider - groups map[string]deltaGroup - groupDeltaLink string - users map[string]deltaUser - userDeltaLink string + provider *Provider + groups map[string]deltaGroup + groupDeltaLink string + servicePrincipals map[string]deltaServicePrincipal + servicePrincipalDeltaLink string + users map[string]deltaUser + userDeltaLink string } deltaGroup struct { id string @@ -35,13 +39,18 @@ type ( displayName string email string } + deltaServicePrincipal struct { + id string + displayName string + } ) func newDeltaCollection(p *Provider) *deltaCollection { return &deltaCollection{ - provider: p, - groups: make(map[string]deltaGroup), - users: make(map[string]deltaUser), + provider: p, + groups: make(map[string]deltaGroup), + users: make(map[string]deltaUser), + servicePrincipals: make(map[string]deltaServicePrincipal), } } @@ -58,12 +67,11 @@ func newDeltaCollection(p *Provider) *deltaCollection { // // Only the changed groups/members are returned. Removed groups/members have an @removed property. func (dc *deltaCollection) Sync(ctx context.Context) error { - err := dc.syncGroups(ctx) - if err != nil { - return err - } - - return dc.syncUsers(ctx) + return errors.Join( + dc.syncGroups(ctx), + dc.syncServicePrincipals(ctx), + dc.syncUsers(ctx), + ) } func (dc *deltaCollection) syncGroups(ctx context.Context) error { @@ -126,6 +134,50 @@ func (dc *deltaCollection) syncGroups(ctx context.Context) error { } } +func (dc *deltaCollection) syncServicePrincipals(ctx context.Context) error { + apiURL := dc.servicePrincipalDeltaLink + + // if no delta link is set yet, start the initial fill + if apiURL == "" { + apiURL = dc.provider.cfg.graphURL.ResolveReference(&url.URL{ + Path: servicePrincipalsDeltaPath, + RawQuery: url.Values{ + "$select": {"displayName"}, + }.Encode(), + }).String() + } + + for { + var res servicePrincipalsDeltaResponse + err := dc.provider.api(ctx, apiURL, &res) + if err != nil { + return err + } + + for _, sp := range res.Value { + // if removed exists, the service principal was deleted + if sp.Removed != nil { + delete(dc.servicePrincipals, sp.ID) + continue + } + dc.servicePrincipals[sp.ID] = deltaServicePrincipal{ + id: sp.ID, + displayName: sp.DisplayName, + } + } + + switch { + case res.NextLink != "": + // when there's a next link we will query again + apiURL = res.NextLink + default: + // once no next link is set anymore, we save the delta link and return + dc.servicePrincipalDeltaLink = res.DeltaLink + return nil + } + } +} + func (dc *deltaCollection) syncUsers(ctx context.Context) error { apiURL := dc.userDeltaLink @@ -186,7 +238,8 @@ func (dc *deltaCollection) CurrentUserGroups() ([]directory.Group, []directory.U switch m.memberType { case "#microsoft.graph.group": groupIDs = append(groupIDs, m.id) - case "#microsoft.graph.user": + case "#microsoft.graph.servicePrincipal", + "#microsoft.graph.user": userIDs = append(userIDs, m.id) } } @@ -197,6 +250,13 @@ func (dc *deltaCollection) CurrentUserGroups() ([]directory.Group, []directory.U }) var users []directory.User + for _, sp := range dc.servicePrincipals { + users = append(users, directory.User{ + ID: sp.id, + GroupIDs: groupLookup.getGroupIDsForUser(sp.id), + DisplayName: sp.displayName, + }) + } for _, u := range dc.users { users = append(users, directory.User{ ID: u.id, @@ -235,6 +295,17 @@ type ( Removed *deltaResponseRemoved `json:"@removed,omitempty"` } + servicePrincipalsDeltaResponse struct { + Context string `json:"@odata.context"` + NextLink string `json:"@odata.nextLink,omitempty"` + DeltaLink string `json:"@odata.deltaLink,omitempty"` + Value []servicePrincipalsDeltaResponseServicePrincipal `json:"value"` + } + servicePrincipalsDeltaResponseServicePrincipal struct { + apiServicePrincipal + Removed *deltaResponseRemoved `json:"@removed,omitempty"` + } + usersDeltaResponse struct { Context string `json:"@odata.context"` NextLink string `json:"@odata.nextLink,omitempty"`