From f7ea5407827624a459e9c5545446c24de40dda73 Mon Sep 17 00:00:00 2001
From: Matt Brock
Date: Wed, 22 Jan 2025 14:27:15 -0600
Subject: [PATCH] Adding Azure sync functionality which can be used by the Azure Fetcher (#50367)

* Protobuf and configuration for Access Graph Azure Discovery
* Adding the Azure sync module functions along with new cloud client functionality
* Moving reconciliation to the upstream azure sync PR
* Moving reconciliation test to the upstream azure sync PR
* Fixing rebase after protobuf gen
* Updating to use existing msgraph client
* PR feedback
* Using variadic options
* Removing memberOf expansion
* Expanding memberships by calling memberOf on each user
* PR feedback
* Rebase go.sum stuff
* Go mod tidy
* Fixing go.mod
* Update lib/msgraph/paginated.go

Co-authored-by: Tiago Silva

* PR feedback
* e ref update
* Adding the Azure sync module functions along with new cloud client functionality
* Protobuf and configuration for Access Graph Azure Discovery
* Adding Azure sync functionality which can be called by the Azure fetcher
* Protobuf update
* Update sync process to use msgraph client
* Conformant package name
* Invoking membership expansion
* Setting principals before expansion
* Removing msgraphclient
* Update e ref
* Linting
* PR feedback
* Adding test names to reconciliation tests
* Adding channel buffer
* Going back to just reading from channel
* Linting
* PR feedback
* PR feedback
* PR feedback
* Apply suggestions from code review

Co-authored-by: Tiago Silva

* PR feedback
* Fixing flaky test
* Lint
* Fix imports

---------

Co-authored-by: Tiago Silva
---
 .../fetchers/azure-sync/azure-sync.go      | 251 ++++++++++++++++++
 .../fetchers/azure-sync/azure-sync_test.go | 234 ++++++++++++++++
 .../fetchers/azure-sync/reconcile.go       | 165 ++++++++++++
 .../fetchers/azure-sync/reconcile_test.go  | 197 ++++++++++++++
 lib/utils/slices/slices.go                 |  14 +
 lib/utils/slices/slices_test.go            |  90 +++++++
 6 files changed, 951 insertions(+)
 create mode 100644 lib/srv/discovery/fetchers/azure-sync/azure-sync.go
 create mode 100644 lib/srv/discovery/fetchers/azure-sync/azure-sync_test.go
 create mode 100644 lib/srv/discovery/fetchers/azure-sync/reconcile.go
 create mode 100644 lib/srv/discovery/fetchers/azure-sync/reconcile_test.go

diff --git a/lib/srv/discovery/fetchers/azure-sync/azure-sync.go b/lib/srv/discovery/fetchers/azure-sync/azure-sync.go
new file mode 100644
index 0000000000000..84dac1d26ec3c
--- /dev/null
+++ b/lib/srv/discovery/fetchers/azure-sync/azure-sync.go
@@ -0,0 +1,251 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package azuresync
+
+import (
+	"context"
+
+	"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
+	"github.com/gravitational/trace"
+	"golang.org/x/sync/errgroup"
+
+	accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha"
+	"github.com/gravitational/teleport/lib/cloud/azure"
+	"github.com/gravitational/teleport/lib/msgraph"
+	"github.com/gravitational/teleport/lib/utils/slices"
+)
+
+// fetcherConcurrency is an arbitrary per-resource type concurrency to ensure significant throughput. As we increase
+// the number of resource types, we may increase this value or use some other approach to fetching concurrency.
+const fetcherConcurrency = 4
+
+// Config defines parameters required for fetching resources from Azure
+type Config struct {
+	// SubscriptionID is the Azure subscription ID
+	SubscriptionID string
+	// Integration is the name of the associated Teleport integration
+	Integration string
+	// DiscoveryConfigName is the name of this Discovery configuration
+	DiscoveryConfigName string
+}
+
+// Resources represents the set of resources fetched from Azure
+type Resources struct {
+	// Principals are Azure users, groups, and service principals
+	Principals []*accessgraphv1alpha.AzurePrincipal
+	// RoleDefinitions are Azure role definitions
+	RoleDefinitions []*accessgraphv1alpha.AzureRoleDefinition
+	// RoleAssignments are Azure role assignments
+	RoleAssignments []*accessgraphv1alpha.AzureRoleAssignment
+	// VirtualMachines are Azure virtual machines
+	VirtualMachines []*accessgraphv1alpha.AzureVirtualMachine
+}
+
+// Fetcher provides the functionality for fetching resources from Azure
+type Fetcher struct {
+	// Config is the configuration values for this fetcher
+	Config
+	// lastError is the last error returned from polling
+	lastError error
+	// lastDiscoveredResources is the number of resources last returned from polling
+	lastDiscoveredResources uint64
+	// lastResult is the last set of resources returned from polling
+	lastResult *Resources
+
+	// graphClient is the MS graph client for fetching principals
+	graphClient *msgraph.Client
+	// roleAssignClient is the Azure client for fetching role assignments
+	roleAssignClient RoleAssignmentsClient
+	// roleDefClient is the Azure client for fetching role definitions
+	roleDefClient RoleDefinitionsClient
+	// vmClient is the Azure client for fetching virtual machines
+	vmClient VirtualMachinesClient
+}
+
+// NewFetcher returns a new fetcher based on configuration parameters
+func NewFetcher(cfg Config, ctx context.Context) (*Fetcher, error) {
+	// Establish the credential from the managed identity
+	cred, err := azidentity.NewDefaultAzureCredential(nil)
+	if err != nil {
+		return nil, trace.Wrap(err)
+	}
+
+	// Create the clients for the fetcher
+	graphClient, err := msgraph.NewClient(msgraph.Config{
+		TokenProvider: cred,
+	})
+	if err != nil {
+		return nil, trace.Wrap(err)
+	}
+	roleAssignClient, err := azure.NewRoleAssignmentsClient(cfg.SubscriptionID, cred, nil)
+	if err != nil {
+		return nil, trace.Wrap(err)
+	}
+	roleDefClient, err := azure.NewRoleDefinitionsClient(cfg.SubscriptionID, cred, nil)
+	if err != nil {
+		return nil, trace.Wrap(err)
+	}
+	vmClient, err := azure.NewVirtualMachinesClient(cfg.SubscriptionID, cred, nil)
+	if err != nil {
+		return nil, trace.Wrap(err)
+	}
+
+	return &Fetcher{
+		Config:           cfg,
+		lastResult:       &Resources{},
+		graphClient:      graphClient,
+		roleAssignClient: roleAssignClient,
+		roleDefClient:    roleDefClient,
+		vmClient:         vmClient,
+	}, nil
+}
+
+const (
+	featNamePrincipals      = "azure/principals"
+	featNameRoleDefinitions = "azure/roledefinitions"
+	featNameRoleAssignments = "azure/roleassignments"
+	featNameVms             = "azure/virtualmachines"
+)
+
+// Features is a set of booleans that are received from the Access Graph to indicate which resources it can receive
+type Features struct {
+	// Principals indicates Azure principals can be fetched
+	Principals bool
+	// RoleDefinitions indicates Azure role definitions can be fetched
+	RoleDefinitions bool
+	// RoleAssignments indicates Azure role assignments can be fetched
+	RoleAssignments bool
+	// VirtualMachines indicates Azure virtual machines can be fetched
+	VirtualMachines bool
+}
+
+// BuildFeatures builds the feature flags based on supported types returned by Access Graph Azure endpoints.
+func BuildFeatures(values ...string) Features {
+	features := Features{}
+	for _, value := range values {
+		switch value {
+		case featNamePrincipals:
+			features.Principals = true
+		case featNameRoleAssignments:
+			features.RoleAssignments = true
+		case featNameRoleDefinitions:
+			features.RoleDefinitions = true
+		case featNameVms:
+			features.VirtualMachines = true
+		}
+	}
+	return features
+}
+
+// Poll fetches and deduplicates Azure resources specified by the Access Graph
+func (f *Fetcher) Poll(ctx context.Context, feats Features) (*Resources, error) {
+	res, err := f.fetch(ctx, feats)
+	if res == nil {
+		return nil, trace.Wrap(err)
+	}
+	res.Principals = slices.DeduplicateKey(res.Principals, azurePrincipalsKey)
+	res.RoleAssignments = slices.DeduplicateKey(res.RoleAssignments, azureRoleAssignKey)
+	res.RoleDefinitions = slices.DeduplicateKey(res.RoleDefinitions, azureRoleDefKey)
+	res.VirtualMachines = slices.DeduplicateKey(res.VirtualMachines, azureVmKey)
+	return res, trace.Wrap(err)
+}
+
+// fetch returns the resources specified by the Access Graph
+func (f *Fetcher) fetch(ctx context.Context, feats Features) (*Resources, error) {
+	// Accumulate Azure resources
+	eg, ctx := errgroup.WithContext(ctx)
+	eg.SetLimit(fetcherConcurrency)
+	var result = &Resources{}
+	// we use a larger value (50) here so there is always room for any returned error to be sent to errsCh without blocking.
+	errsCh := make(chan error, 50)
+	if feats.Principals {
+		eg.Go(func() error {
+			principals, err := fetchPrincipals(ctx, f.SubscriptionID, f.graphClient)
+			if err != nil {
+				errsCh <- err
+				return nil
+			}
+			principals, err = expandMemberships(ctx, f.graphClient, principals)
+			if err != nil {
+				errsCh <- err
+				return nil
+			}
+			result.Principals = principals
+			return nil
+		})
+	}
+	if feats.RoleAssignments {
+		eg.Go(func() error {
+			roleAssigns, err := fetchRoleAssignments(ctx, f.SubscriptionID, f.roleAssignClient)
+			if err != nil {
+				errsCh <- err
+				return nil
+			}
+			result.RoleAssignments = roleAssigns
+			return nil
+		})
+	}
+	if feats.RoleDefinitions {
+		eg.Go(func() error {
+			roleDefs, err := fetchRoleDefinitions(ctx, f.SubscriptionID, f.roleDefClient)
+			if err != nil {
+				errsCh <- err
+				return nil
+			}
+			result.RoleDefinitions = roleDefs
+			return nil
+		})
+	}
+	if feats.VirtualMachines {
+		eg.Go(func() error {
+			vms, err := fetchVirtualMachines(ctx, f.SubscriptionID, f.vmClient)
+			if err != nil {
+				errsCh <- err
+				return nil
+			}
+			result.VirtualMachines = vms
+			return nil
+		})
+	}
+
+	// Return the result along with any errors collected
+	_ = eg.Wait()
+	close(errsCh)
+	return result, trace.NewAggregateFromChannel(errsCh, context.WithoutCancel(ctx))
+}
+
+// Status returns the number of resources last fetched and/or the last fetching/reconciling error
+func (f *Fetcher) Status() (uint64, error) {
+	return f.lastDiscoveredResources, f.lastError
+}
+
+// DiscoveryConfigName returns the name of the configured discovery
+func (f *Fetcher) DiscoveryConfigName() string {
+	return f.Config.DiscoveryConfigName
+}
+
+// IsFromDiscoveryConfig returns whether the discovery is from configuration or dynamic
+func (f *Fetcher) IsFromDiscoveryConfig() bool {
+	return f.Config.DiscoveryConfigName != ""
+}
+
+// GetSubscriptionID returns the ID of the Azure subscription
+func (f *Fetcher) GetSubscriptionID() string {
+	return f.Config.SubscriptionID
+}
diff --git a/lib/srv/discovery/fetchers/azure-sync/azure-sync_test.go b/lib/srv/discovery/fetchers/azure-sync/azure-sync_test.go
new file mode 100644
index 0000000000000..261555b21578e
--- /dev/null
+++ b/lib/srv/discovery/fetchers/azure-sync/azure-sync_test.go
@@ -0,0 +1,234 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */ + +package azuresync + +import ( + "context" + "fmt" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v2" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" + "github.com/stretchr/testify/require" +) + +type testRoleDefCli struct { + returnErr bool + roleDefs []*armauthorization.RoleDefinition +} + +func (t testRoleDefCli) ListRoleDefinitions(ctx context.Context, scope string) ([]*armauthorization.RoleDefinition, error) { + if t.returnErr { + return nil, fmt.Errorf("error") + } + return t.roleDefs, nil +} + +type testRoleAssignCli struct { + returnErr bool + roleAssigns []*armauthorization.RoleAssignment +} + +func (t testRoleAssignCli) ListRoleAssignments(ctx context.Context, scope string) ([]*armauthorization.RoleAssignment, error) { + if t.returnErr { + return nil, fmt.Errorf("error") + } + return t.roleAssigns, nil +} + +type testVmCli struct { + returnErr bool + vms []*armcompute.VirtualMachine +} + +func (t testVmCli) ListVirtualMachines(ctx context.Context, resourceGroup string) ([]*armcompute.VirtualMachine, error) { + if t.returnErr { + return nil, fmt.Errorf("error") + } + return t.vms, nil +} + +func newRoleDef(id string, name string) *armauthorization.RoleDefinition { + roleName := "test_role_name" + action1 := "Microsoft.Compute/virtualMachines/read" + action2 := "Microsoft.Compute/virtualMachines/*" + action3 := "Microsoft.Compute/*" + return &armauthorization.RoleDefinition{ + ID: &id, + Name: &name, + Properties: &armauthorization.RoleDefinitionProperties{ + Permissions: []*armauthorization.Permission{ + { + Actions: []*string{&action1, &action2}, + }, + { + Actions: []*string{&action3}, + }, + }, + RoleName: &roleName, + }, + } +} + +func newRoleAssign(id string, name string) *armauthorization.RoleAssignment { + scope := "test_scope" + principalId := "test_principal_id" + roleDefId := "test_role_def_id" + return &armauthorization.RoleAssignment{ + ID: &id, + Name: &name, + Properties: &armauthorization.RoleAssignmentProperties{ + PrincipalID: &principalId, + RoleDefinitionID: &roleDefId, + Scope: &scope, + }, + } +} + +func newVm(id string, name string) *armcompute.VirtualMachine { + return &armcompute.VirtualMachine{ + ID: &id, + Name: &name, + } +} + +func TestPoll(t *testing.T) { + roleDefs := []*armauthorization.RoleDefinition{ + newRoleDef("id1", "name1"), + } + roleAssigns := []*armauthorization.RoleAssignment{ + newRoleAssign("id1", "name1"), + } + vms := []*armcompute.VirtualMachine{ + newVm("id1", "name2"), + } + roleDefClient := testRoleDefCli{} + roleAssignClient := testRoleAssignCli{} + vmClient := testVmCli{} + fetcher := Fetcher{ + Config: Config{SubscriptionID: "1234567890"}, + lastResult: &Resources{}, + roleDefClient: &roleDefClient, + roleAssignClient: &roleAssignClient, + vmClient: &vmClient, + } + ctx := context.Background() + allFeats := Features{ + RoleDefinitions: true, + RoleAssignments: true, + VirtualMachines: true, + } + noVmsFeats := allFeats + noVmsFeats.VirtualMachines = false + + tests := []struct { + name string + returnErr bool + roleDefs []*armauthorization.RoleDefinition + roleAssigns []*armauthorization.RoleAssignment + vms []*armcompute.VirtualMachine + feats Features + }{ + // Process no results from clients + { + name: "WithoutResults", + returnErr: false, + roleDefs: []*armauthorization.RoleDefinition{}, + roleAssigns: []*armauthorization.RoleAssignment{}, + vms: []*armcompute.VirtualMachine{}, + feats: allFeats, + }, + // Process test results from 
clients + { + name: "WithResults", + returnErr: false, + roleDefs: roleDefs, + roleAssigns: roleAssigns, + vms: vms, + feats: allFeats, + }, + // Handle errors from clients + { + name: "PollErrors", + returnErr: true, + roleDefs: roleDefs, + roleAssigns: roleAssigns, + vms: vms, + feats: allFeats, + }, + // Handle VM features being disabled + { + name: "NoVmsFeats", + returnErr: false, + roleDefs: roleDefs, + roleAssigns: roleAssigns, + vms: vms, + feats: noVmsFeats, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set the test data + roleDefClient.returnErr = tt.returnErr + roleDefClient.roleDefs = tt.roleDefs + roleAssignClient.returnErr = tt.returnErr + roleAssignClient.roleAssigns = tt.roleAssigns + vmClient.returnErr = tt.returnErr + vmClient.vms = tt.vms + + // Poll for resources + resources, err := fetcher.Poll(ctx, tt.feats) + + // Require no error unless otherwise specified + if tt.returnErr { + require.Error(t, err) + return + } + require.NoError(t, err) + + // Verify the results, based on the features set + require.NotNil(t, resources) + require.Equal(t, tt.feats.RoleDefinitions == false || len(tt.roleDefs) == 0, len(resources.RoleDefinitions) == 0) + for idx, resource := range resources.RoleDefinitions { + roleDef := tt.roleDefs[idx] + require.Equal(t, *roleDef.ID, resource.Id) + require.Equal(t, fetcher.SubscriptionID, resource.SubscriptionId) + require.Equal(t, *roleDef.Properties.RoleName, resource.Name) + require.Len(t, roleDef.Properties.Permissions, len(resource.Permissions)) + } + require.Equal(t, tt.feats.RoleAssignments == false || len(tt.roleAssigns) == 0, len(resources.RoleAssignments) == 0) + for idx, resource := range resources.RoleAssignments { + roleAssign := tt.roleAssigns[idx] + require.Equal(t, *roleAssign.ID, resource.Id) + require.Equal(t, fetcher.SubscriptionID, resource.SubscriptionId) + require.Equal(t, *roleAssign.Properties.PrincipalID, resource.PrincipalId) + require.Equal(t, *roleAssign.Properties.RoleDefinitionID, resource.RoleDefinitionId) + require.Equal(t, *roleAssign.Properties.Scope, resource.Scope) + } + require.Equal(t, tt.feats.VirtualMachines == false || len(tt.vms) == 0, len(resources.VirtualMachines) == 0) + for idx, resource := range resources.VirtualMachines { + vm := tt.vms[idx] + require.Equal(t, *vm.ID, resource.Id) + require.Equal(t, fetcher.SubscriptionID, resource.SubscriptionId) + require.Equal(t, *vm.Name, resource.Name) + } + }) + } +} diff --git a/lib/srv/discovery/fetchers/azure-sync/reconcile.go b/lib/srv/discovery/fetchers/azure-sync/reconcile.go new file mode 100644 index 0000000000000..a874f48215811 --- /dev/null +++ b/lib/srv/discovery/fetchers/azure-sync/reconcile.go @@ -0,0 +1,165 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . 
+ */ + +package azuresync + +import ( + "fmt" + + "google.golang.org/protobuf/proto" + + accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha" + "github.com/gravitational/teleport/lib/utils/slices" +) + +// MergeResources merges Azure resources fetched from multiple configured Azure fetchers +func MergeResources(results ...*Resources) *Resources { + if len(results) == 0 { + return &Resources{} + } + if len(results) == 1 { + return results[0] + } + result := &Resources{} + for _, r := range results { + result.Principals = append(result.Principals, r.Principals...) + result.RoleAssignments = append(result.RoleAssignments, r.RoleAssignments...) + result.RoleDefinitions = append(result.RoleDefinitions, r.RoleDefinitions...) + result.VirtualMachines = append(result.VirtualMachines, r.VirtualMachines...) + } + result.Principals = slices.DeduplicateKey(result.Principals, azurePrincipalsKey) + result.RoleAssignments = slices.DeduplicateKey(result.RoleAssignments, azureRoleAssignKey) + result.RoleDefinitions = slices.DeduplicateKey(result.RoleDefinitions, azureRoleDefKey) + result.VirtualMachines = slices.DeduplicateKey(result.VirtualMachines, azureVmKey) + return result +} + +// newResourceList creates a new resource list message +func newResourceList() *accessgraphv1alpha.AzureResourceList { + return &accessgraphv1alpha.AzureResourceList{ + Resources: make([]*accessgraphv1alpha.AzureResource, 0), + } +} + +// ReconcileResults compares previously and currently fetched results and determines which resources to upsert and +// which to delete. +func ReconcileResults(old *Resources, new *Resources) (upsert, delete *accessgraphv1alpha.AzureResourceList) { + upsert, delete = newResourceList(), newResourceList() + reconciledResources := []*reconcilePair{ + reconcile(old.Principals, new.Principals, azurePrincipalsKey, azurePrincipalsWrap), + reconcile(old.RoleAssignments, new.RoleAssignments, azureRoleAssignKey, azureRoleAssignWrap), + reconcile(old.RoleDefinitions, new.RoleDefinitions, azureRoleDefKey, azureRoleDefWrap), + reconcile(old.VirtualMachines, new.VirtualMachines, azureVmKey, azureVmWrap), + } + for _, res := range reconciledResources { + upsert.Resources = append(upsert.Resources, res.upsert.Resources...) + delete.Resources = append(delete.Resources, res.delete.Resources...) 
+ } + return upsert, delete +} + +// reconcilePair contains the Azure resources to upsert and delete +type reconcilePair struct { + upsert, delete *accessgraphv1alpha.AzureResourceList +} + +// reconcile compares old and new items to build a list of resources to upsert and delete in the Access Graph +func reconcile[T proto.Message]( + oldItems []T, + newItems []T, + keyFn func(T) string, + wrapFn func(T) *accessgraphv1alpha.AzureResource, +) *reconcilePair { + // Remove duplicates from the new items + newItems = slices.DeduplicateKey(newItems, keyFn) + upsertRes := newResourceList() + deleteRes := newResourceList() + + // Delete all old items if there are no new items + if len(newItems) == 0 { + for _, item := range oldItems { + deleteRes.Resources = append(deleteRes.Resources, wrapFn(item)) + } + return &reconcilePair{upsertRes, deleteRes} + } + + // Create all new items if there are no old items + if len(oldItems) == 0 { + for _, item := range newItems { + upsertRes.Resources = append(upsertRes.Resources, wrapFn(item)) + } + return &reconcilePair{upsertRes, deleteRes} + } + + // Map old and new items by their key + oldMap := make(map[string]T, len(oldItems)) + for _, item := range oldItems { + oldMap[keyFn(item)] = item + } + newMap := make(map[string]T, len(newItems)) + for _, item := range newItems { + newMap[keyFn(item)] = item + } + + // Append new or modified items to the upsert list + for _, item := range newItems { + if oldItem, ok := oldMap[keyFn(item)]; !ok || !proto.Equal(oldItem, item) { + upsertRes.Resources = append(upsertRes.Resources, wrapFn(item)) + } + } + + // Append removed items to the delete list + for _, item := range oldItems { + if _, ok := newMap[keyFn(item)]; !ok { + deleteRes.Resources = append(deleteRes.Resources, wrapFn(item)) + } + } + return &reconcilePair{upsertRes, deleteRes} +} + +func azurePrincipalsKey(user *accessgraphv1alpha.AzurePrincipal) string { + return fmt.Sprintf("%s:%s", user.SubscriptionId, user.Id) +} + +func azurePrincipalsWrap(principal *accessgraphv1alpha.AzurePrincipal) *accessgraphv1alpha.AzureResource { + return &accessgraphv1alpha.AzureResource{Resource: &accessgraphv1alpha.AzureResource_Principal{Principal: principal}} +} + +func azureRoleAssignKey(roleAssign *accessgraphv1alpha.AzureRoleAssignment) string { + return fmt.Sprintf("%s:%s", roleAssign.SubscriptionId, roleAssign.Id) +} + +func azureRoleAssignWrap(roleAssign *accessgraphv1alpha.AzureRoleAssignment) *accessgraphv1alpha.AzureResource { + return &accessgraphv1alpha.AzureResource{Resource: &accessgraphv1alpha.AzureResource_RoleAssignment{RoleAssignment: roleAssign}} +} + +func azureRoleDefKey(roleDef *accessgraphv1alpha.AzureRoleDefinition) string { + return fmt.Sprintf("%s:%s", roleDef.SubscriptionId, roleDef.Id) +} + +func azureRoleDefWrap(roleDef *accessgraphv1alpha.AzureRoleDefinition) *accessgraphv1alpha.AzureResource { + return &accessgraphv1alpha.AzureResource{Resource: &accessgraphv1alpha.AzureResource_RoleDefinition{RoleDefinition: roleDef}} +} + +func azureVmKey(vm *accessgraphv1alpha.AzureVirtualMachine) string { + return fmt.Sprintf("%s:%s", vm.SubscriptionId, vm.Id) +} + +func azureVmWrap(vm *accessgraphv1alpha.AzureVirtualMachine) *accessgraphv1alpha.AzureResource { + return &accessgraphv1alpha.AzureResource{Resource: &accessgraphv1alpha.AzureResource_VirtualMachine{VirtualMachine: vm}} +} diff --git a/lib/srv/discovery/fetchers/azure-sync/reconcile_test.go b/lib/srv/discovery/fetchers/azure-sync/reconcile_test.go new file mode 100644 index 
0000000000000..3652c4963218c --- /dev/null +++ b/lib/srv/discovery/fetchers/azure-sync/reconcile_test.go @@ -0,0 +1,197 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package azuresync + +import ( + "testing" + + "github.com/stretchr/testify/require" + + accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha" +) + +func TestReconcileResults(t *testing.T) { + principals := generatePrincipals() + roleDefs := generateRoleDefs() + roleAssigns := generateRoleAssigns() + vms := generateVms() + + tests := []struct { + name string + oldResults *Resources + newResults *Resources + expectedUpserts *accessgraphv1alpha.AzureResourceList + expectedDeletes *accessgraphv1alpha.AzureResourceList + }{ + // Overlapping old and new results + { + name: "OverlapOldAndNewResults", + oldResults: &Resources{ + Principals: principals[0:2], + RoleDefinitions: roleDefs[0:2], + RoleAssignments: roleAssigns[0:2], + VirtualMachines: vms[0:2], + }, + newResults: &Resources{ + Principals: principals[1:3], + RoleDefinitions: roleDefs[1:3], + RoleAssignments: roleAssigns[1:3], + VirtualMachines: vms[1:3], + }, + expectedUpserts: generateExpected(principals[2:3], roleDefs[2:3], roleAssigns[2:3], vms[2:3]), + expectedDeletes: generateExpected(principals[0:1], roleDefs[0:1], roleAssigns[0:1], vms[0:1]), + }, + // Completely new results + { + name: "CompletelyNewResults", + oldResults: &Resources{ + Principals: nil, + RoleDefinitions: nil, + RoleAssignments: nil, + VirtualMachines: nil, + }, + newResults: &Resources{ + Principals: principals[1:3], + RoleDefinitions: roleDefs[1:3], + RoleAssignments: roleAssigns[1:3], + VirtualMachines: vms[1:3], + }, + expectedUpserts: generateExpected(principals[1:3], roleDefs[1:3], roleAssigns[1:3], vms[1:3]), + expectedDeletes: generateExpected(nil, nil, nil, nil), + }, + // No new results + { + name: "NoNewResults", + oldResults: &Resources{ + Principals: principals[1:3], + RoleDefinitions: roleDefs[1:3], + RoleAssignments: roleAssigns[1:3], + VirtualMachines: vms[1:3], + }, + newResults: &Resources{ + Principals: nil, + RoleDefinitions: nil, + RoleAssignments: nil, + VirtualMachines: nil, + }, + expectedUpserts: generateExpected(nil, nil, nil, nil), + expectedDeletes: generateExpected(principals[1:3], roleDefs[1:3], roleAssigns[1:3], vms[1:3]), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + upserts, deletes := ReconcileResults(tt.oldResults, tt.newResults) + require.ElementsMatch(t, upserts.Resources, tt.expectedUpserts.Resources) + require.ElementsMatch(t, deletes.Resources, tt.expectedDeletes.Resources) + }) + } + +} + +func generateExpected( + principals []*accessgraphv1alpha.AzurePrincipal, + roleDefs []*accessgraphv1alpha.AzureRoleDefinition, + roleAssigns []*accessgraphv1alpha.AzureRoleAssignment, + vms []*accessgraphv1alpha.AzureVirtualMachine, +) 
*accessgraphv1alpha.AzureResourceList { + resList := &accessgraphv1alpha.AzureResourceList{ + Resources: make([]*accessgraphv1alpha.AzureResource, 0), + } + for _, principal := range principals { + resList.Resources = append(resList.Resources, azurePrincipalsWrap(principal)) + } + for _, roleDef := range roleDefs { + resList.Resources = append(resList.Resources, azureRoleDefWrap(roleDef)) + } + for _, roleAssign := range roleAssigns { + resList.Resources = append(resList.Resources, azureRoleAssignWrap(roleAssign)) + } + for _, vm := range vms { + resList.Resources = append(resList.Resources, azureVmWrap(vm)) + } + return resList +} + +func generatePrincipals() []*accessgraphv1alpha.AzurePrincipal { + return []*accessgraphv1alpha.AzurePrincipal{ + { + Id: "/principals/foo", + DisplayName: "userFoo", + }, + { + Id: "/principals/bar", + DisplayName: "userBar", + }, + { + Id: "/principals/charles", + DisplayName: "userCharles", + }, + } +} + +func generateRoleDefs() []*accessgraphv1alpha.AzureRoleDefinition { + return []*accessgraphv1alpha.AzureRoleDefinition{ + { + Id: "/roledefinitions/foo", + Name: "roleFoo", + }, + { + Id: "/roledefinitions/bar", + Name: "roleBar", + }, + { + Id: "/roledefinitions/charles", + Name: "roleCharles", + }, + } +} + +func generateRoleAssigns() []*accessgraphv1alpha.AzureRoleAssignment { + return []*accessgraphv1alpha.AzureRoleAssignment{ + { + Id: "/roleassignments/foo", + PrincipalId: "userFoo", + }, + { + Id: "/roleassignments/bar", + PrincipalId: "userBar", + }, + { + Id: "/roleassignments/charles", + PrincipalId: "userCharles", + }, + } +} + +func generateVms() []*accessgraphv1alpha.AzureVirtualMachine { + return []*accessgraphv1alpha.AzureVirtualMachine{ + { + Id: "/vms/foo", + Name: "userFoo", + }, + { + Id: "/vms/bar", + Name: "userBar", + }, + { + Id: "/vms/charles", + Name: "userCharles", + }, + } +} diff --git a/lib/utils/slices/slices.go b/lib/utils/slices/slices.go index 3c33c0baf8710..077911c8738cf 100644 --- a/lib/utils/slices/slices.go +++ b/lib/utils/slices/slices.go @@ -57,3 +57,17 @@ func FromPointers[T any](in []*T) []T { } return out } + +// DeduplicateKey returns a deduplicated slice by comparing key values from the key function +func DeduplicateKey[T any](s []T, key func(T) string) []T { + out := make([]T, 0, len(s)) + seen := make(map[string]struct{}) + for _, v := range s { + if _, ok := seen[key(v)]; ok { + continue + } + seen[key(v)] = struct{}{} + out = append(out, v) + } + return out +} diff --git a/lib/utils/slices/slices_test.go b/lib/utils/slices/slices_test.go index 1f031d21eca49..b0fe86a1bfd2e 100644 --- a/lib/utils/slices/slices_test.go +++ b/lib/utils/slices/slices_test.go @@ -19,6 +19,7 @@ package slices import ( + "fmt" "strings" "testing" @@ -92,3 +93,92 @@ func TestFilterMapUnique(t *testing.T) { require.Equal(t, expected, got) }) } + +// TestDuplicateKey tests slice deduplication via key function +func TestDeduplicateKey(t *testing.T) { + t.Parallel() + + stringTests := []struct { + name string + slice []string + keyFn func(string) string + expected []string + }{ + { + name: "EmptyStringSlice", + slice: []string{}, + keyFn: func(s string) string { return s }, + expected: []string{}, + }, + { + name: "NoStringDuplicates", + slice: []string{"foo", "bar", "baz"}, + keyFn: func(s string) string { return s }, + expected: []string{"foo", "bar", "baz"}, + }, + { + name: "StringDuplicates", + slice: []string{"foo", "bar", "bar", "bar", "baz", "baz"}, + keyFn: func(s string) string { return s }, + expected: []string{"foo", "bar", 
"baz"}, + }, + { + name: "StringDuplicatesWeirdKeyFn", + slice: []string{"foo", "bar", "bar", "bar", "baz", "baz"}, + keyFn: func(s string) string { return "huh" }, + expected: []string{"foo"}, + }, + } + for _, tt := range stringTests { + t.Run(tt.name, func(t *testing.T) { + res := DeduplicateKey(tt.slice, tt.keyFn) + require.Equal(t, tt.expected, res) + }) + } + + type dedupeStruct struct { + a string + b int + c bool + } + dedupeStructKeyFn := func(d dedupeStruct) string { return fmt.Sprintf("%s:%d:%v", d.a, d.b, d.c) } + structTests := []struct { + name string + slice []dedupeStruct + keyFn func(d dedupeStruct) string + expected []dedupeStruct + }{ + { + name: "EmptySlice", + slice: []dedupeStruct{}, + keyFn: dedupeStructKeyFn, + expected: []dedupeStruct{}, + }, + { + name: "NoStructDuplicates", + slice: []dedupeStruct{ + {a: "foo", b: 1, c: true}, + {a: "foo", b: 1, c: false}, + {a: "foo", b: 2, c: true}, + {a: "bar", b: 1, c: true}, + {a: "bar", b: 1, c: false}, + {a: "bar", b: 2, c: true}, + }, + keyFn: dedupeStructKeyFn, + expected: []dedupeStruct{ + {a: "foo", b: 1, c: true}, + {a: "foo", b: 1, c: false}, + {a: "foo", b: 2, c: true}, + {a: "bar", b: 1, c: true}, + {a: "bar", b: 1, c: false}, + {a: "bar", b: 2, c: true}, + }, + }, + } + for _, tt := range structTests { + t.Run(tt.name, func(t *testing.T) { + res := DeduplicateKey(tt.slice, tt.keyFn) + require.Equal(t, tt.expected, res) + }) + } +}