From ef3d7bf7c20154a2104d043e6dde8abd45a7a50e Mon Sep 17 00:00:00 2001 From: Jack Francis Date: Mon, 8 Apr 2024 17:02:47 -0700 Subject: [PATCH] enable per-sub msi client Signed-off-by: Jack Francis --- azure/services/identities/client.go | 13 +++++++++++++ .../services/virtualmachines/virtualmachines.go | 13 ++++++++++++- .../virtualmachines/virtualmachines_test.go | 10 ++++++++-- controllers/azurejson_machine_controller.go | 16 ++++++++++++++-- controllers/azurejson_machinepool_controller.go | 16 ++++++++++++++-- .../azurejson_machinepool_controller_test.go | 4 ++-- .../azurejson_machinetemplate_controller.go | 16 ++++++++++++++-- 7 files changed, 77 insertions(+), 11 deletions(-) diff --git a/azure/services/identities/client.go b/azure/services/identities/client.go index ca406d2914f..c72c98b5bd4 100644 --- a/azure/services/identities/client.go +++ b/azure/services/identities/client.go @@ -51,6 +51,19 @@ func NewClient(auth azure.Authorizer) (Client, error) { return &AzureClient{factory.NewUserAssignedIdentitiesClient()}, nil } +// NewClientBySub creates a new MSI client with a given subscriptionID. +func NewClientBySub(auth azure.Authorizer, subscriptionID string) (Client, error) { + opts, err := azure.ARMClientOptions(auth.CloudEnvironment()) + if err != nil { + return nil, errors.Wrap(err, "failed to create identities client options") + } + factory, err := armmsi.NewClientFactory(subscriptionID, auth.Token(), opts) + if err != nil { + return nil, errors.Wrap(err, "failed to create armmsi client factory") + } + return &AzureClient{factory.NewUserAssignedIdentitiesClient()}, nil +} + // Get returns a managed service identity. func (ac *AzureClient) Get(ctx context.Context, resourceGroupName, name string) (armmsi.Identity, error) { ctx, _, done := tele.StartSpanWithLogger(ctx, "identities.AzureClient.Get") diff --git a/azure/services/virtualmachines/virtualmachines.go b/azure/services/virtualmachines/virtualmachines.go index 32c7a74d7bb..5d0986021d1 100644 --- a/azure/services/virtualmachines/virtualmachines.go +++ b/azure/services/virtualmachines/virtualmachines.go @@ -175,7 +175,18 @@ func (s *Service) checkUserAssignedIdentities(ctx context.Context, specIdentitie // Create a map of the expected identities. The ProviderID is converted to match the format of the VM identity. for _, expectedIdentity := range specIdentities { - expectedClientID, err := s.identitiesGetter.GetClientID(ctx, expectedIdentity.ProviderID) + identitiesClient := s.identitiesGetter + parsed, err := azureutil.ParseResourceID(expectedIdentity.ProviderID) + if err != nil { + return err + } + if parsed.SubscriptionID != s.Scope.SubscriptionID() { + identitiesClient, err = identities.NewClientBySub(s.Scope, parsed.SubscriptionID) + if err != nil { + return errors.Wrapf(err, "failed to create identities client from subscription ID %s", parsed.SubscriptionID) + } + } + expectedClientID, err := identitiesClient.GetClientID(ctx, expectedIdentity.ProviderID) if err != nil { return errors.Wrap(err, "failed to get client ID") } diff --git a/azure/services/virtualmachines/virtualmachines_test.go b/azure/services/virtualmachines/virtualmachines_test.go index fd663f99cca..85ac15234ab 100644 --- a/azure/services/virtualmachines/virtualmachines_test.go +++ b/azure/services/virtualmachines/virtualmachines_test.go @@ -114,10 +114,10 @@ var ( }, } fakeUserAssignedIdentity = infrav1.UserAssignedIdentity{ - ProviderID: "fake-provider-id", + ProviderID: "azure:///subscriptions/123/resourceGroups/test-rg/providers/Microsoft.ManagedIdentity/userAssignedIdentities/fake-provider-id", } fakeUserAssignedIdentity2 = infrav1.UserAssignedIdentity{ - ProviderID: "fake-provider-id-2", + ProviderID: "azure:///subscriptions/123/resourceGroups/test-rg/providers/Microsoft.ManagedIdentity/userAssignedIdentities/fake-provider-id-2", } ) @@ -335,6 +335,7 @@ func TestCheckUserAssignedIdentities(t *testing.T) { specIdentities: []infrav1.UserAssignedIdentity{fakeUserAssignedIdentity}, actualIdentities: []infrav1.UserAssignedIdentity{fakeUserAssignedIdentity}, expect: func(s *mock_virtualmachines.MockVMScopeMockRecorder, i *mock_identities.MockClientMockRecorder) { + s.SubscriptionID().Return("123") i.GetClientID(gomockinternal.AContext(), fakeUserAssignedIdentity.ProviderID).AnyTimes().Return(fakeUserAssignedIdentity.ProviderID, nil) }, expectedError: "", @@ -344,6 +345,7 @@ func TestCheckUserAssignedIdentities(t *testing.T) { specIdentities: []infrav1.UserAssignedIdentity{fakeUserAssignedIdentity, fakeUserAssignedIdentity2}, actualIdentities: []infrav1.UserAssignedIdentity{fakeUserAssignedIdentity}, expect: func(s *mock_virtualmachines.MockVMScopeMockRecorder, i *mock_identities.MockClientMockRecorder) { + s.SubscriptionID().AnyTimes().Return("123") i.GetClientID(gomockinternal.AContext(), fakeUserAssignedIdentity.ProviderID).AnyTimes().Return(fakeUserAssignedIdentity.ProviderID, nil) i.GetClientID(gomockinternal.AContext(), fakeUserAssignedIdentity2.ProviderID).AnyTimes().Return(fakeUserAssignedIdentity2.ProviderID, nil) s.SetConditionFalse(infrav1.VMIdentitiesReadyCondition, infrav1.UserAssignedIdentityMissingReason, clusterv1.ConditionSeverityWarning, "VM is missing expected user assigned identity with client ID: "+fakeUserAssignedIdentity2.ProviderID).Times(1) @@ -355,6 +357,7 @@ func TestCheckUserAssignedIdentities(t *testing.T) { specIdentities: []infrav1.UserAssignedIdentity{fakeUserAssignedIdentity}, actualIdentities: []infrav1.UserAssignedIdentity{fakeUserAssignedIdentity, fakeUserAssignedIdentity2}, expect: func(s *mock_virtualmachines.MockVMScopeMockRecorder, i *mock_identities.MockClientMockRecorder) { + s.SubscriptionID().Return("123") i.GetClientID(gomockinternal.AContext(), fakeUserAssignedIdentity.ProviderID).AnyTimes().Return(fakeUserAssignedIdentity.ProviderID, nil) }, expectedError: "", @@ -364,6 +367,7 @@ func TestCheckUserAssignedIdentities(t *testing.T) { specIdentities: []infrav1.UserAssignedIdentity{fakeUserAssignedIdentity}, actualIdentities: []infrav1.UserAssignedIdentity{fakeUserAssignedIdentity2}, expect: func(s *mock_virtualmachines.MockVMScopeMockRecorder, i *mock_identities.MockClientMockRecorder) { + s.SubscriptionID().Return("123") i.GetClientID(gomockinternal.AContext(), fakeUserAssignedIdentity.ProviderID).AnyTimes().Return(fakeUserAssignedIdentity.ProviderID, nil) s.SetConditionFalse(infrav1.VMIdentitiesReadyCondition, infrav1.UserAssignedIdentityMissingReason, clusterv1.ConditionSeverityWarning, "VM is missing expected user assigned identity with client ID: "+fakeUserAssignedIdentity.ProviderID).Times(1) }, @@ -374,6 +378,7 @@ func TestCheckUserAssignedIdentities(t *testing.T) { specIdentities: []infrav1.UserAssignedIdentity{fakeUserAssignedIdentity, fakeUserAssignedIdentity}, actualIdentities: []infrav1.UserAssignedIdentity{fakeUserAssignedIdentity}, expect: func(s *mock_virtualmachines.MockVMScopeMockRecorder, i *mock_identities.MockClientMockRecorder) { + s.SubscriptionID().AnyTimes().Return("123") i.GetClientID(gomockinternal.AContext(), fakeUserAssignedIdentity.ProviderID).AnyTimes().Return(fakeUserAssignedIdentity.ProviderID, nil) }, expectedError: "", @@ -383,6 +388,7 @@ func TestCheckUserAssignedIdentities(t *testing.T) { specIdentities: []infrav1.UserAssignedIdentity{fakeUserAssignedIdentity}, actualIdentities: []infrav1.UserAssignedIdentity{fakeUserAssignedIdentity}, expect: func(s *mock_virtualmachines.MockVMScopeMockRecorder, i *mock_identities.MockClientMockRecorder) { + s.SubscriptionID().Return("123") i.GetClientID(gomockinternal.AContext(), fakeUserAssignedIdentity.ProviderID).AnyTimes().Return("", errors.New("failed to get client id")) }, expectedError: "failed to get client id", diff --git a/controllers/azurejson_machine_controller.go b/controllers/azurejson_machine_controller.go index 6a49a3f8a08..2eb5e5e4f03 100644 --- a/controllers/azurejson_machine_controller.go +++ b/controllers/azurejson_machine_controller.go @@ -31,6 +31,7 @@ import ( infrav1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1" "sigs.k8s.io/cluster-api-provider-azure/azure/scope" "sigs.k8s.io/cluster-api-provider-azure/azure/services/identities" + azureutil "sigs.k8s.io/cluster-api-provider-azure/util/azure" "sigs.k8s.io/cluster-api-provider-azure/util/reconciler" "sigs.k8s.io/cluster-api-provider-azure/util/tele" clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1" @@ -215,11 +216,22 @@ func (r *AzureJSONMachineReconciler) Reconcile(ctx context.Context, req ctrl.Req // Construct secret for this machine userAssignedIdentityIfExists := "" if len(azureMachine.Spec.UserAssignedIdentities) > 0 { - idsClient, err := identities.NewClient(clusterScope) + var identitiesClient identities.Client + identitiesClient, err := identities.NewClient(clusterScope) if err != nil { return reconcile.Result{}, errors.Wrap(err, "failed to create identities client") } - userAssignedIdentityIfExists, err = idsClient.GetClientID( + parsed, err := azureutil.ParseResourceID(azureMachine.Spec.UserAssignedIdentities[0].ProviderID) + if err != nil { + return reconcile.Result{}, errors.Wrapf(err, "failed to parse ProviderID %s", azureMachine.Spec.UserAssignedIdentities[0].ProviderID) + } + if parsed.SubscriptionID != clusterScope.SubscriptionID() { + identitiesClient, err = identities.NewClientBySub(clusterScope, parsed.SubscriptionID) + if err != nil { + return reconcile.Result{}, errors.Wrapf(err, "failed to create identities client from subscription ID %s", parsed.SubscriptionID) + } + } + userAssignedIdentityIfExists, err = identitiesClient.GetClientID( ctx, azureMachine.Spec.UserAssignedIdentities[0].ProviderID) if err != nil { return reconcile.Result{}, errors.Wrap(err, "failed to get user-assigned identity ClientID") diff --git a/controllers/azurejson_machinepool_controller.go b/controllers/azurejson_machinepool_controller.go index 09d5af0773a..9eac1dc5778 100644 --- a/controllers/azurejson_machinepool_controller.go +++ b/controllers/azurejson_machinepool_controller.go @@ -29,6 +29,7 @@ import ( infrav1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1" "sigs.k8s.io/cluster-api-provider-azure/azure/services/identities" infrav1exp "sigs.k8s.io/cluster-api-provider-azure/exp/api/v1beta1" + azureutil "sigs.k8s.io/cluster-api-provider-azure/util/azure" "sigs.k8s.io/cluster-api-provider-azure/util/reconciler" "sigs.k8s.io/cluster-api-provider-azure/util/tele" clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1" @@ -148,11 +149,22 @@ func (r *AzureJSONMachinePoolReconciler) Reconcile(ctx context.Context, req ctrl // Construct secret for this machine userAssignedIdentityIfExists := "" if len(azureMachinePool.Spec.UserAssignedIdentities) > 0 { - idsClient, err := getClient(clusterScope) + var identitiesClient identities.Client + identitiesClient, err := getClient(clusterScope) if err != nil { return reconcile.Result{}, errors.Wrap(err, "failed to create identities client") } - userAssignedIdentityIfExists, err = idsClient.GetClientID( + parsed, err := azureutil.ParseResourceID(azureMachinePool.Spec.UserAssignedIdentities[0].ProviderID) + if err != nil { + return reconcile.Result{}, errors.Wrapf(err, "failed to parse ProviderID %s", azureMachinePool.Spec.UserAssignedIdentities[0].ProviderID) + } + if parsed.SubscriptionID != clusterScope.SubscriptionID() { + identitiesClient, err = identities.NewClientBySub(clusterScope, parsed.SubscriptionID) + if err != nil { + return reconcile.Result{}, errors.Wrapf(err, "failed to create identities client from subscription ID %s", parsed.SubscriptionID) + } + } + userAssignedIdentityIfExists, err = identitiesClient.GetClientID( ctx, azureMachinePool.Spec.UserAssignedIdentities[0].ProviderID) if err != nil { return reconcile.Result{}, errors.Wrap(err, "failed to get user-assigned identity ClientID") diff --git a/controllers/azurejson_machinepool_controller_test.go b/controllers/azurejson_machinepool_controller_test.go index ec7a5ad1694..c3c6e88dfed 100644 --- a/controllers/azurejson_machinepool_controller_test.go +++ b/controllers/azurejson_machinepool_controller_test.go @@ -271,7 +271,7 @@ func TestAzureJSONPoolReconcilerUserAssignedIdentities(t *testing.T) { Spec: infrav1exp.AzureMachinePoolSpec{ UserAssignedIdentities: []infrav1.UserAssignedIdentity{ { - ProviderID: "fake-id", + ProviderID: "azure:///subscriptions/123/resourceGroups/test-rg/providers/Microsoft.ManagedIdentity/userAssignedIdentities/fake-provider-id", }, }, }, @@ -378,7 +378,7 @@ func TestAzureJSONPoolReconcilerUserAssignedIdentities(t *testing.T) { Recorder: record.NewFakeRecorder(42), Timeouts: reconciler.Timeouts{}, } - id := "fake-id" + id := "azure:///subscriptions/123/resourceGroups/test-rg/providers/Microsoft.ManagedIdentity/userAssignedIdentities/fake-provider-id" getClient = func(auth azure.Authorizer) (identities.Client, error) { mockClient := mock_identities.NewMockClient(ctrlr) mockClient.EXPECT().GetClientID(gomock.Any(), gomock.Any()).Return(id, nil) diff --git a/controllers/azurejson_machinetemplate_controller.go b/controllers/azurejson_machinetemplate_controller.go index 359ae971a57..baea0ee8e18 100644 --- a/controllers/azurejson_machinetemplate_controller.go +++ b/controllers/azurejson_machinetemplate_controller.go @@ -30,6 +30,7 @@ import ( infrav1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1" "sigs.k8s.io/cluster-api-provider-azure/azure/scope" "sigs.k8s.io/cluster-api-provider-azure/azure/services/identities" + azureutil "sigs.k8s.io/cluster-api-provider-azure/util/azure" "sigs.k8s.io/cluster-api-provider-azure/util/reconciler" "sigs.k8s.io/cluster-api-provider-azure/util/tele" clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1" @@ -175,11 +176,22 @@ func (r *AzureJSONTemplateReconciler) Reconcile(ctx context.Context, req ctrl.Re // Construct secret for this machine template userAssignedIdentityIfExists := "" if len(azureMachineTemplate.Spec.Template.Spec.UserAssignedIdentities) > 0 { - idsClient, err := identities.NewClient(clusterScope) + var identitiesClient identities.Client + identitiesClient, err := identities.NewClient(clusterScope) if err != nil { return reconcile.Result{}, errors.Wrap(err, "failed to create identities client") } - userAssignedIdentityIfExists, err = idsClient.GetClientID( + parsed, err := azureutil.ParseResourceID(azureMachineTemplate.Spec.Template.Spec.UserAssignedIdentities[0].ProviderID) + if err != nil { + return reconcile.Result{}, errors.Wrapf(err, "failed to parse ProviderID %s", azureMachineTemplate.Spec.Template.Spec.UserAssignedIdentities[0].ProviderID) + } + if parsed.SubscriptionID != clusterScope.SubscriptionID() { + identitiesClient, err = identities.NewClientBySub(clusterScope, parsed.SubscriptionID) + if err != nil { + return reconcile.Result{}, errors.Wrapf(err, "failed to create identities client from subscription ID %s", parsed.SubscriptionID) + } + } + userAssignedIdentityIfExists, err = identitiesClient.GetClientID( ctx, azureMachineTemplate.Spec.Template.Spec.UserAssignedIdentities[0].ProviderID) if err != nil { return reconcile.Result{}, errors.Wrap(err, "failed to get user-assigned identity ClientID")