From f789524433b152796c619557be901452e20193d9 Mon Sep 17 00:00:00 2001 From: umagnus Date: Wed, 13 Dec 2023 03:12:00 +0000 Subject: [PATCH] fix --- pkg/blob/controllerserver.go | 76 +++++++++++++++++++++++-------- pkg/blob/controllerserver_test.go | 12 ++--- 2 files changed, 64 insertions(+), 24 deletions(-) diff --git a/pkg/blob/controllerserver.go b/pkg/blob/controllerserver.go index d6ba79375..a0aa25fc1 100644 --- a/pkg/blob/controllerserver.go +++ b/pkg/blob/controllerserver.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "net/url" + "os" "os/exec" "strconv" "strings" @@ -92,7 +93,6 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) var matchTags, useDataPlaneAPI, getLatestAccountKey bool var softDeleteBlobs, softDeleteContainers int32 var vnetResourceIDs []string - var storageIdentityClientID string var err error // set allowBlobPublicAccess as false by default allowBlobPublicAccess := pointer.Bool(false) @@ -176,7 +176,6 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) case serverNameField: case storageAuthTypeField: case storageIdentityClientIDField: - storageIdentityClientID = v case storageIdentityObjectIDField: case storageIdentityResourceIDField: case msiEndpointField: @@ -418,7 +417,7 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) return nil, status.Errorf(codes.Internal, "failed to GetStorageAccesskey on account(%s) rg(%s), error: %v", accountOptions.Name, accountOptions.ResourceGroup, err) } } - if err := d.copyVolume(ctx, req, accountKey, validContainerName, storageEndpointSuffix, storageIdentityClientID); err != nil { + if err := d.copyVolume(ctx, req, accountKey, validContainerName, storageEndpointSuffix); err != nil { return nil, err } } else { @@ -713,7 +712,7 @@ func (d *Driver) DeleteBlobContainer(ctx context.Context, subsID, resourceGroupN } // CopyBlobContainer copies a blob container in the same storage account -func (d *Driver) copyBlobContainer(_ context.Context, req *csi.CreateVolumeRequest, accountKey, dstContainerName, storageEndpointSuffix, storageIdentityClientID string) error { +func (d *Driver) copyBlobContainer(_ context.Context, req *csi.CreateVolumeRequest, accountKey, dstContainerName, storageEndpointSuffix string) error { var sourceVolumeID string if req.GetVolumeContentSource() != nil && req.GetVolumeContentSource().GetVolume() != nil { sourceVolumeID = req.GetVolumeContentSource().GetVolume().GetVolumeId() @@ -728,7 +727,12 @@ func (d *Driver) copyBlobContainer(_ context.Context, req *csi.CreateVolumeReque } var accountSasToken string - if storageIdentityClientID == "" { + if d.cloud.Config.AzureAuthConfig.UseFederatedWorkloadIdentityExtension || d.cloud.Config.AzureAuthConfig.UseManagedIdentityExtension { + err = d.authorizeAzcopyBySecurityPrincipal() + if err != nil { + return err + } + } else { klog.V(2).Infof("generate sas token for account(%s)", accountName) accountSasToken, err = generateSASToken(accountName, accountKey, storageEndpointSuffix, d.sasTokenExpirationMinutes) if err != nil { @@ -757,17 +761,6 @@ func (d *Driver) copyBlobContainer(_ context.Context, req *csi.CreateVolumeReque return err case util.AzcopyJobNotFound: klog.V(2).Infof("copy blob container %s to %s", srcContainerName, dstContainerName) - if storageIdentityClientID != "" { - klog.V(2).Infof("use msi client id to authorize azcopy") - _, err = exec.Command("export", "AZCOPY_AUTO_LOGIN_TYPE=MSI").CombinedOutput() - if err != nil { - return err - } - _, err = exec.Command("export", fmt.Sprintf("AZCOPY_MSI_CLIENT_ID=%s", storageIdentityClientID)).CombinedOutput() - if err != nil { - return err - } - } out, copyErr := exec.Command("azcopy", "copy", srcPath, dstPath, "--recursive", "--check-length=false").CombinedOutput() if copyErr != nil { klog.Warningf("CopyBlobContainer(%s, %s, %s) failed with error(%v): %v", resourceGroupName, accountName, dstPath, copyErr, string(out)) @@ -783,18 +776,65 @@ func (d *Driver) copyBlobContainer(_ context.Context, req *csi.CreateVolumeReque } // copyVolume copies a volume form volume or snapshot, snapshot is not supported now -func (d *Driver) copyVolume(ctx context.Context, req *csi.CreateVolumeRequest, accountKey, dstContainerName, storageEndpointSuffix, storageIdentityClientID string) error { +func (d *Driver) copyVolume(ctx context.Context, req *csi.CreateVolumeRequest, accountKey, dstContainerName, storageEndpointSuffix string) error { vs := req.VolumeContentSource switch vs.Type.(type) { case *csi.VolumeContentSource_Snapshot: return status.Errorf(codes.InvalidArgument, "copy volume from volumeSnapshot is not supported") case *csi.VolumeContentSource_Volume: - return d.copyBlobContainer(ctx, req, accountKey, dstContainerName, storageEndpointSuffix, storageIdentityClientID) + return d.copyBlobContainer(ctx, req, accountKey, dstContainerName, storageEndpointSuffix) default: return status.Errorf(codes.InvalidArgument, "%v is not a proper volume source", vs) } } +func (d *Driver) authorizeAzcopyBySecurityPrincipal() error { + if d.cloud.Config.AzureAuthConfig.UseFederatedWorkloadIdentityExtension { + klog.V(2).Infof("use service principal to authorize azcopy") + os.Setenv("AZCOPY_AUTO_LOGIN_TYPE", "SPN") + if err := os.Setenv("AZCOPY_AUTO_LOGIN_TYPE", "SPN"); err != nil { + klog.Errorf("failed to set AZCOPY_AUTO_LOGIN_TYPE=SPN, error: %v", err) + return err + } + if d.cloud.Config.AADClientID == "" || d.cloud.Config.AADClientSecret == "" || d.cloud.Config.TenantID == "" { + return fmt.Errorf("AADClientID: %s, AADClientSecret: %s and TenantID: %s must be set when use federated workload identity extension", d.cloud.Config.AADClientID, d.cloud.Config.AADClientSecret, d.cloud.Config.TenantID) + } + if err := os.Setenv("AZCOPY_SPA_APPLICATION_ID", d.cloud.Config.AADClientID); err != nil { + klog.Errorf("failed to set AZCOPY_SPA_APPLICATION_ID, error: %v", err) + return err + } + if err := os.Setenv("AZCOPY_SPA_CLIENT_SECRET", d.cloud.Config.AADClientSecret); err != nil { + klog.Errorf("failed to set AZCOPY_SPA_CLIENT_SECRET, error: %v", err) + return err + } + if err := os.Setenv("AZCOPY_TENANT_ID", d.cloud.Config.TenantID); err != nil { + klog.Errorf("failed to set AZCOPY_TENANT_ID, error: %v", err) + return err + } + klog.V(2).Infof("set AZCOPY_AUTO_LOGIN_TYPE=SPN, AZCOPY_SPA_APPLICATION_ID, AZCOPY_SPA_CLIENT_SECRET and AZCOPY_TENANT_ID successfully") + return nil + } + if d.cloud.Config.AzureAuthConfig.UseManagedIdentityExtension { + klog.V(2).Infof("use managed identity to authorize azcopy") + if err := os.Setenv("AZCOPY_AUTO_LOGIN_TYPE", "MSI"); err != nil { + klog.Errorf("failed to set AZCOPY_AUTO_LOGIN_TYPE=MSI, error: %v", err) + } + klog.V(2).Infof("set AZCOPY_AUTO_LOGIN_TYPE=MSI successfully") + if len(d.cloud.Config.UserAssignedIdentityID) > 0 { + klog.V(2).Infof("authorize by using a user-assigned managed identity") + if err := os.Setenv("AZCOPY_MSI_CLIENT_ID", d.cloud.Config.UserAssignedIdentityID); err != nil { + klog.Errorf("failed to set AZCOPY_MSI_CLIENT_ID, error: %v", err) + return err + } + klog.V(2).Infof("set AZCOPY_MSI_CLIENT_ID successfully") + return nil + } + klog.V(2).Infof("authorize by using a system-assigned managed identity") + return nil + } + return fmt.Errorf("useFederatedWorkloadIdentityExtension or useManagedIdentityExtension must be set to true") +} + // isValidVolumeCapabilities validates the given VolumeCapability array is valid func isValidVolumeCapabilities(volCaps []*csi.VolumeCapability) error { if len(volCaps) == 0 { diff --git a/pkg/blob/controllerserver_test.go b/pkg/blob/controllerserver_test.go index 080bd9ec9..0f7e03747 100644 --- a/pkg/blob/controllerserver_test.go +++ b/pkg/blob/controllerserver_test.go @@ -1496,7 +1496,7 @@ func TestCopyVolume(t *testing.T) { ctx := context.Background() expectedErr := status.Errorf(codes.InvalidArgument, "copy volume from volumeSnapshot is not supported") - err := d.copyVolume(ctx, req, "", "", "core.windows.net", "") + err := d.copyVolume(ctx, req, "", "", "core.windows.net") if !reflect.DeepEqual(err, expectedErr) { t.Errorf("Unexpected error: %v", err) } @@ -1528,7 +1528,7 @@ func TestCopyVolume(t *testing.T) { ctx := context.Background() expectedErr := status.Errorf(codes.NotFound, "error parsing volume id: \"unit-test\", should at least contain two #") - err := d.copyVolume(ctx, req, "", "dstContainer", "core.windows.net", "") + err := d.copyVolume(ctx, req, "", "dstContainer", "core.windows.net") if !reflect.DeepEqual(err, expectedErr) { t.Errorf("Unexpected error: %v", err) } @@ -1560,7 +1560,7 @@ func TestCopyVolume(t *testing.T) { ctx := context.Background() expectedErr := fmt.Errorf("srcContainerName() or dstContainerName(dstContainer) is empty") - err := d.copyVolume(ctx, req, "", "dstContainer", "core.windows.net", "") + err := d.copyVolume(ctx, req, "", "dstContainer", "core.windows.net") if !reflect.DeepEqual(err, expectedErr) { t.Errorf("Unexpected error: %v", err) } @@ -1592,7 +1592,7 @@ func TestCopyVolume(t *testing.T) { ctx := context.Background() expectedErr := fmt.Errorf("srcContainerName(fileshare) or dstContainerName() is empty") - err := d.copyVolume(ctx, req, "", "", "core.windows.net", "") + err := d.copyVolume(ctx, req, "", "", "core.windows.net") if !reflect.DeepEqual(err, expectedErr) { t.Errorf("Unexpected error: %v", err) } @@ -1636,7 +1636,7 @@ func TestCopyVolume(t *testing.T) { ctx := context.Background() var expectedErr error - err := d.copyVolume(ctx, req, "", "dstContainer", "core.windows.net", "") + err := d.copyVolume(ctx, req, "", "dstContainer", "core.windows.net") if !reflect.DeepEqual(err, expectedErr) { t.Errorf("Unexpected error: %v", err) } @@ -1681,7 +1681,7 @@ func TestCopyVolume(t *testing.T) { ctx := context.Background() var expectedErr error - err := d.copyVolume(ctx, req, "", "dstContainer", "core.windows.net", "") + err := d.copyVolume(ctx, req, "", "dstContainer", "core.windows.net") if !reflect.DeepEqual(err, expectedErr) { t.Errorf("Unexpected error: %v", err) }