Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
umagnus committed Dec 13, 2023
1 parent ec2f7a0 commit f789524
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 24 deletions.
76 changes: 58 additions & 18 deletions pkg/blob/controllerserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"fmt"
"net/url"
"os"
"os/exec"
"strconv"
"strings"
Expand Down Expand Up @@ -92,7 +93,6 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest)
var matchTags, useDataPlaneAPI, getLatestAccountKey bool
var softDeleteBlobs, softDeleteContainers int32
var vnetResourceIDs []string
var storageIdentityClientID string
var err error
// set allowBlobPublicAccess as false by default
allowBlobPublicAccess := pointer.Bool(false)
Expand Down Expand Up @@ -176,7 +176,6 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest)
case serverNameField:
case storageAuthTypeField:
case storageIdentityClientIDField:
storageIdentityClientID = v
case storageIdentityObjectIDField:
case storageIdentityResourceIDField:
case msiEndpointField:
Expand Down Expand Up @@ -418,7 +417,7 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest)
return nil, status.Errorf(codes.Internal, "failed to GetStorageAccesskey on account(%s) rg(%s), error: %v", accountOptions.Name, accountOptions.ResourceGroup, err)
}
}
if err := d.copyVolume(ctx, req, accountKey, validContainerName, storageEndpointSuffix, storageIdentityClientID); err != nil {
if err := d.copyVolume(ctx, req, accountKey, validContainerName, storageEndpointSuffix); err != nil {
return nil, err
}
} else {
Expand Down Expand Up @@ -713,7 +712,7 @@ func (d *Driver) DeleteBlobContainer(ctx context.Context, subsID, resourceGroupN
}

// CopyBlobContainer copies a blob container in the same storage account
func (d *Driver) copyBlobContainer(_ context.Context, req *csi.CreateVolumeRequest, accountKey, dstContainerName, storageEndpointSuffix, storageIdentityClientID string) error {
func (d *Driver) copyBlobContainer(_ context.Context, req *csi.CreateVolumeRequest, accountKey, dstContainerName, storageEndpointSuffix string) error {
var sourceVolumeID string
if req.GetVolumeContentSource() != nil && req.GetVolumeContentSource().GetVolume() != nil {
sourceVolumeID = req.GetVolumeContentSource().GetVolume().GetVolumeId()
Expand All @@ -728,7 +727,12 @@ func (d *Driver) copyBlobContainer(_ context.Context, req *csi.CreateVolumeReque
}

var accountSasToken string
if storageIdentityClientID == "" {
if d.cloud.Config.AzureAuthConfig.UseFederatedWorkloadIdentityExtension || d.cloud.Config.AzureAuthConfig.UseManagedIdentityExtension {
err = d.authorizeAzcopyBySecurityPrincipal()
if err != nil {
return err
}
} else {
klog.V(2).Infof("generate sas token for account(%s)", accountName)
accountSasToken, err = generateSASToken(accountName, accountKey, storageEndpointSuffix, d.sasTokenExpirationMinutes)
if err != nil {
Expand Down Expand Up @@ -757,17 +761,6 @@ func (d *Driver) copyBlobContainer(_ context.Context, req *csi.CreateVolumeReque
return err
case util.AzcopyJobNotFound:
klog.V(2).Infof("copy blob container %s to %s", srcContainerName, dstContainerName)
if storageIdentityClientID != "" {
klog.V(2).Infof("use msi client id to authorize azcopy")
_, err = exec.Command("export", "AZCOPY_AUTO_LOGIN_TYPE=MSI").CombinedOutput()
if err != nil {
return err
}
_, err = exec.Command("export", fmt.Sprintf("AZCOPY_MSI_CLIENT_ID=%s", storageIdentityClientID)).CombinedOutput()
if err != nil {
return err
}
}
out, copyErr := exec.Command("azcopy", "copy", srcPath, dstPath, "--recursive", "--check-length=false").CombinedOutput()
if copyErr != nil {
klog.Warningf("CopyBlobContainer(%s, %s, %s) failed with error(%v): %v", resourceGroupName, accountName, dstPath, copyErr, string(out))
Expand All @@ -783,18 +776,65 @@ func (d *Driver) copyBlobContainer(_ context.Context, req *csi.CreateVolumeReque
}

// copyVolume copies a volume form volume or snapshot, snapshot is not supported now
func (d *Driver) copyVolume(ctx context.Context, req *csi.CreateVolumeRequest, accountKey, dstContainerName, storageEndpointSuffix, storageIdentityClientID string) error {
func (d *Driver) copyVolume(ctx context.Context, req *csi.CreateVolumeRequest, accountKey, dstContainerName, storageEndpointSuffix string) error {
vs := req.VolumeContentSource
switch vs.Type.(type) {
case *csi.VolumeContentSource_Snapshot:
return status.Errorf(codes.InvalidArgument, "copy volume from volumeSnapshot is not supported")
case *csi.VolumeContentSource_Volume:
return d.copyBlobContainer(ctx, req, accountKey, dstContainerName, storageEndpointSuffix, storageIdentityClientID)
return d.copyBlobContainer(ctx, req, accountKey, dstContainerName, storageEndpointSuffix)
default:
return status.Errorf(codes.InvalidArgument, "%v is not a proper volume source", vs)
}
}

func (d *Driver) authorizeAzcopyBySecurityPrincipal() error {
if d.cloud.Config.AzureAuthConfig.UseFederatedWorkloadIdentityExtension {
klog.V(2).Infof("use service principal to authorize azcopy")
os.Setenv("AZCOPY_AUTO_LOGIN_TYPE", "SPN")
if err := os.Setenv("AZCOPY_AUTO_LOGIN_TYPE", "SPN"); err != nil {
klog.Errorf("failed to set AZCOPY_AUTO_LOGIN_TYPE=SPN, error: %v", err)
return err
}
if d.cloud.Config.AADClientID == "" || d.cloud.Config.AADClientSecret == "" || d.cloud.Config.TenantID == "" {
return fmt.Errorf("AADClientID: %s, AADClientSecret: %s and TenantID: %s must be set when use federated workload identity extension", d.cloud.Config.AADClientID, d.cloud.Config.AADClientSecret, d.cloud.Config.TenantID)
}
if err := os.Setenv("AZCOPY_SPA_APPLICATION_ID", d.cloud.Config.AADClientID); err != nil {
klog.Errorf("failed to set AZCOPY_SPA_APPLICATION_ID, error: %v", err)
return err
}
if err := os.Setenv("AZCOPY_SPA_CLIENT_SECRET", d.cloud.Config.AADClientSecret); err != nil {
klog.Errorf("failed to set AZCOPY_SPA_CLIENT_SECRET, error: %v", err)
return err
}
if err := os.Setenv("AZCOPY_TENANT_ID", d.cloud.Config.TenantID); err != nil {
klog.Errorf("failed to set AZCOPY_TENANT_ID, error: %v", err)
return err
}
klog.V(2).Infof("set AZCOPY_AUTO_LOGIN_TYPE=SPN, AZCOPY_SPA_APPLICATION_ID, AZCOPY_SPA_CLIENT_SECRET and AZCOPY_TENANT_ID successfully")
return nil
}
if d.cloud.Config.AzureAuthConfig.UseManagedIdentityExtension {
klog.V(2).Infof("use managed identity to authorize azcopy")
if err := os.Setenv("AZCOPY_AUTO_LOGIN_TYPE", "MSI"); err != nil {
klog.Errorf("failed to set AZCOPY_AUTO_LOGIN_TYPE=MSI, error: %v", err)
}
klog.V(2).Infof("set AZCOPY_AUTO_LOGIN_TYPE=MSI successfully")
if len(d.cloud.Config.UserAssignedIdentityID) > 0 {
klog.V(2).Infof("authorize by using a user-assigned managed identity")
if err := os.Setenv("AZCOPY_MSI_CLIENT_ID", d.cloud.Config.UserAssignedIdentityID); err != nil {
klog.Errorf("failed to set AZCOPY_MSI_CLIENT_ID, error: %v", err)
return err
}
klog.V(2).Infof("set AZCOPY_MSI_CLIENT_ID successfully")
return nil
}
klog.V(2).Infof("authorize by using a system-assigned managed identity")
return nil
}
return fmt.Errorf("useFederatedWorkloadIdentityExtension or useManagedIdentityExtension must be set to true")
}

// isValidVolumeCapabilities validates the given VolumeCapability array is valid
func isValidVolumeCapabilities(volCaps []*csi.VolumeCapability) error {
if len(volCaps) == 0 {
Expand Down
12 changes: 6 additions & 6 deletions pkg/blob/controllerserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1496,7 +1496,7 @@ func TestCopyVolume(t *testing.T) {
ctx := context.Background()

expectedErr := status.Errorf(codes.InvalidArgument, "copy volume from volumeSnapshot is not supported")
err := d.copyVolume(ctx, req, "", "", "core.windows.net", "")
err := d.copyVolume(ctx, req, "", "", "core.windows.net")
if !reflect.DeepEqual(err, expectedErr) {
t.Errorf("Unexpected error: %v", err)
}
Expand Down Expand Up @@ -1528,7 +1528,7 @@ func TestCopyVolume(t *testing.T) {
ctx := context.Background()

expectedErr := status.Errorf(codes.NotFound, "error parsing volume id: \"unit-test\", should at least contain two #")
err := d.copyVolume(ctx, req, "", "dstContainer", "core.windows.net", "")
err := d.copyVolume(ctx, req, "", "dstContainer", "core.windows.net")
if !reflect.DeepEqual(err, expectedErr) {
t.Errorf("Unexpected error: %v", err)
}
Expand Down Expand Up @@ -1560,7 +1560,7 @@ func TestCopyVolume(t *testing.T) {
ctx := context.Background()

expectedErr := fmt.Errorf("srcContainerName() or dstContainerName(dstContainer) is empty")
err := d.copyVolume(ctx, req, "", "dstContainer", "core.windows.net", "")
err := d.copyVolume(ctx, req, "", "dstContainer", "core.windows.net")
if !reflect.DeepEqual(err, expectedErr) {
t.Errorf("Unexpected error: %v", err)
}
Expand Down Expand Up @@ -1592,7 +1592,7 @@ func TestCopyVolume(t *testing.T) {
ctx := context.Background()

expectedErr := fmt.Errorf("srcContainerName(fileshare) or dstContainerName() is empty")
err := d.copyVolume(ctx, req, "", "", "core.windows.net", "")
err := d.copyVolume(ctx, req, "", "", "core.windows.net")
if !reflect.DeepEqual(err, expectedErr) {
t.Errorf("Unexpected error: %v", err)
}
Expand Down Expand Up @@ -1636,7 +1636,7 @@ func TestCopyVolume(t *testing.T) {
ctx := context.Background()

var expectedErr error
err := d.copyVolume(ctx, req, "", "dstContainer", "core.windows.net", "")
err := d.copyVolume(ctx, req, "", "dstContainer", "core.windows.net")
if !reflect.DeepEqual(err, expectedErr) {
t.Errorf("Unexpected error: %v", err)
}
Expand Down Expand Up @@ -1681,7 +1681,7 @@ func TestCopyVolume(t *testing.T) {
ctx := context.Background()

var expectedErr error
err := d.copyVolume(ctx, req, "", "dstContainer", "core.windows.net", "")
err := d.copyVolume(ctx, req, "", "dstContainer", "core.windows.net")
if !reflect.DeepEqual(err, expectedErr) {
t.Errorf("Unexpected error: %v", err)
}
Expand Down

0 comments on commit f789524

Please sign in to comment.