diff --git a/hack/update-mock.sh b/hack/update-mock.sh old mode 100644 new mode 100755 diff --git a/pkg/azurefile/azurefile.go b/pkg/azurefile/azurefile.go index 49a0f1f6b7..59cf6eafd6 100644 --- a/pkg/azurefile/azurefile.go +++ b/pkg/azurefile/azurefile.go @@ -23,6 +23,7 @@ import ( "errors" "fmt" "net/url" + "os" "os/exec" "strconv" "strings" @@ -196,6 +197,8 @@ var ( supportedFSGroupChangePolicyList = []string{FSGroupChangeNone, string(v1.FSGroupChangeAlways), string(v1.FSGroupChangeOnRootMismatch)} retriableErrors = []string{accountNotProvisioned, tooManyRequests, shareBeingDeleted, clientThrottled} + + defaultAzcopyCopyOptions = []string{"--recursive", "--check-length=false"} ) // Driver implements all interfaces of CSI drivers @@ -250,6 +253,8 @@ type Driver struct { resizeFileShareFailureCache azcache.Resource // a timed cache storing volume stats volStatsCache azcache.Resource + // a timed cache storing account which should use sastoken for azcopy based volume cloning + azcopySasTokenCache azcache.Resource // sas expiry time for azcopy in volume clone sasTokenExpirationMinutes int // azcopy for provide exec mock for ut @@ -320,6 +325,10 @@ func NewDriver(options *DriverOptions) *Driver { klog.Fatalf("%v", err) } + if driver.azcopySasTokenCache, err = azcache.NewTimedCache(15*time.Minute, getter, false); err != nil { + klog.Fatalf("%v", err) + } + if driver.resizeFileShareFailureCache, err = azcache.NewTimedCache(3*time.Minute, getter, false); err != nil { klog.Fatalf("%v", err) } @@ -977,7 +986,7 @@ func (d *Driver) ResizeFileShare(ctx context.Context, subsID, resourceGroup, acc } // copyFileShare copies a fileshare in the same storage account -func (d *Driver) copyFileShare(req *csi.CreateVolumeRequest, accountKey string, shareOptions *fileclient.ShareOptions, storageEndpointSuffix string) error { +func (d *Driver) copyFileShare(ctx context.Context, req *csi.CreateVolumeRequest, accountSASToken string, authAzcopyEnv []string, secretName, secretNamespace string, secrets map[string]string, shareOptions *fileclient.ShareOptions, accountOptions *azure.AccountOptions, storageEndpointSuffix string) error { if shareOptions.Protocol == storage.EnabledProtocolsNFS { return fmt.Errorf("protocol nfs is not supported for volume cloning") } @@ -994,18 +1003,12 @@ func (d *Driver) copyFileShare(req *csi.CreateVolumeRequest, accountKey string, return fmt.Errorf("srcFileShareName(%s) or dstFileShareName(%s) is empty", srcFileShareName, dstFileShareName) } - klog.V(2).Infof("generate sas token for account(%s)", accountName) - accountSasToken, genErr := generateSASToken(accountName, accountKey, storageEndpointSuffix, d.sasTokenExpirationMinutes) - if genErr != nil { - return genErr - } - timeAfter := time.After(waitForCopyTimeout) timeTick := time.Tick(waitForCopyInterval) - srcPath := fmt.Sprintf("https://%s.file.%s/%s%s", accountName, storageEndpointSuffix, srcFileShareName, accountSasToken) - dstPath := fmt.Sprintf("https://%s.file.%s/%s%s", accountName, storageEndpointSuffix, dstFileShareName, accountSasToken) + srcPath := fmt.Sprintf("https://%s.file.%s/%s%s", accountName, storageEndpointSuffix, srcFileShareName, accountSASToken) + dstPath := fmt.Sprintf("https://%s.file.%s/%s%s", accountName, storageEndpointSuffix, dstFileShareName, accountSASToken) - jobState, percent, err := d.azcopy.GetAzcopyJob(dstFileShareName) + jobState, percent, err := d.azcopy.GetAzcopyJob(dstFileShareName, authAzcopyEnv) klog.V(2).Infof("azcopy job status: %s, copy percent: %s%%, error: %v", jobState, percent, err) if jobState == fileutil.AzcopyJobError || jobState == fileutil.AzcopyJobCompleted { return err @@ -1014,14 +1017,30 @@ func (d *Driver) copyFileShare(req *csi.CreateVolumeRequest, accountKey string, for { select { case <-timeTick: - jobState, percent, err := d.azcopy.GetAzcopyJob(dstFileShareName) + jobState, percent, err := d.azcopy.GetAzcopyJob(dstFileShareName, authAzcopyEnv) klog.V(2).Infof("azcopy job status: %s, copy percent: %s%%, error: %v", jobState, percent, err) switch jobState { case fileutil.AzcopyJobError, fileutil.AzcopyJobCompleted: return err case fileutil.AzcopyJobNotFound: klog.V(2).Infof("copy fileshare %s to %s", srcFileShareName, dstFileShareName) - out, copyErr := exec.Command("azcopy", "copy", srcPath, dstPath, "--recursive", "--check-length=false").CombinedOutput() + cmd := exec.Command("azcopy", "copy", srcPath, dstPath) + cmd.Args = append(cmd.Args, defaultAzcopyCopyOptions...) + if len(authAzcopyEnv) > 0 { + cmd.Env = append(os.Environ(), authAzcopyEnv...) + } + out, copyErr := cmd.CombinedOutput() + if accountSASToken == "" && strings.Contains(string(out), authorizationPermissionMismatch) && copyErr != nil { + klog.Warningf("azcopy list failed with AuthorizationPermissionMismatch error, should assign \"Storage File Data SMB Share Elevated Contributor\" role to controller identity, fall back to use sas token, original output: %v", string(out)) + d.azcopySasTokenCache.Set(accountName, "") + var sasToken string + if sasToken, _, err = d.getAzcopyAuth(ctx, accountName, "", storageEndpointSuffix, accountOptions, secrets, secretName, secretNamespace, true); err != nil { + return err + } + cmd := exec.Command("azcopy", "copy", srcPath+sasToken, dstPath+sasToken) + cmd.Args = append(cmd.Args, defaultAzcopyCopyOptions...) + out, copyErr = cmd.CombinedOutput() + } if copyErr != nil { klog.Warningf("CopyFileShare(%s, %s, %s) failed with error(%v): %v", resourceGroupName, accountName, dstFileShareName, copyErr, string(out)) } else { diff --git a/pkg/azurefile/controllerserver.go b/pkg/azurefile/controllerserver.go index b375d33cc3..e2d0f5c773 100644 --- a/pkg/azurefile/controllerserver.go +++ b/pkg/azurefile/controllerserver.go @@ -53,6 +53,15 @@ const ( privateEndpoint = "privateendpoint" snapshotTimeFormat = "2006-01-02T15:04:05.0000000Z07:00" snapshotsExpand = "snapshots" + + azcopyAutoLoginType = "AZCOPY_AUTO_LOGIN_TYPE" + azcopySPAApplicationID = "AZCOPY_SPA_APPLICATION_ID" + azcopySPAClientSecret = "AZCOPY_SPA_CLIENT_SECRET" + azcopyTenantID = "AZCOPY_TENANT_ID" + azcopyMSIClientID = "AZCOPY_MSI_CLIENT_ID" + MSI = "MSI" + SPN = "SPN" + authorizationPermissionMismatch = "AuthorizationPermissionMismatch" ) var ( @@ -108,7 +117,7 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) if acquired := d.volumeLocks.TryAcquire(volName); !acquired { // logging the job status if it's volume cloning if req.GetVolumeContentSource() != nil { - jobState, percent, err := d.azcopy.GetAzcopyJob(volName) + jobState, percent, err := d.azcopy.GetAzcopyJob(volName, []string{}) klog.V(2).Infof("azcopy job status: %s, copy percent: %s%%, error: %v", jobState, percent, err) } return nil, status.Errorf(codes.Aborted, volumeOperationAlreadyExistsFmt, volName) @@ -571,11 +580,11 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) return nil, status.Errorf(codes.Internal, "failed to create file share(%s) on account(%s) type(%s) subsID(%s) rg(%s) location(%s) size(%d), error: %v", validFileShareName, account, sku, subsID, resourceGroup, location, fileShareSize, err) } if req.GetVolumeContentSource() != nil { - accountKeyCopy, err := d.GetStorageAccesskey(ctx, accountOptions, req.GetSecrets(), secretName, secretNamespace) + accountSASToken, authAzcopyEnv, err := d.getAzcopyAuth(ctx, accountName, accountKey, storageEndpointSuffix, accountOptions, secret, secretName, secretNamespace, false) if err != nil { - return nil, status.Errorf(codes.Internal, "failed to GetStorageAccesskey on account(%s) rg(%s), error: %v", accountOptions.Name, accountOptions.ResourceGroup, err) + return nil, status.Errorf(codes.Internal, "failed to getAzcopyAuth on account(%s) rg(%s), error: %v", accountOptions.Name, accountOptions.ResourceGroup, err) } - if err := d.copyVolume(req, accountKeyCopy, shareOptions, storageEndpointSuffix); err != nil { + if err := d.copyVolume(ctx, req, accountSASToken, authAzcopyEnv, secretName, secretNamespace, secret, shareOptions, accountOptions, storageEndpointSuffix); err != nil { return nil, err } // storeAccountKey is not needed here since copy volume is only using SAS token @@ -726,13 +735,13 @@ func (d *Driver) DeleteVolume(ctx context.Context, req *csi.DeleteVolumeRequest) } // copyVolume copy an azure file -func (d *Driver) copyVolume(req *csi.CreateVolumeRequest, accountKey string, shareOptions *fileclient.ShareOptions, storageEndpointSuffix string) error { +func (d *Driver) copyVolume(ctx context.Context, req *csi.CreateVolumeRequest, accountSASToken string, authAzcopyEnv []string, secretName, secretNamespace string, secrets map[string]string, shareOptions *fileclient.ShareOptions, accountOptions *azure.AccountOptions, storageEndpointSuffix string) error { vs := req.VolumeContentSource switch vs.Type.(type) { case *csi.VolumeContentSource_Snapshot: return status.Errorf(codes.InvalidArgument, "copy volume from volumeSnapshot is not supported") case *csi.VolumeContentSource_Volume: - return d.copyFileShare(req, accountKey, shareOptions, storageEndpointSuffix) + return d.copyFileShare(ctx, req, accountSASToken, authAzcopyEnv, secretName, secretNamespace, secrets, shareOptions, accountOptions, storageEndpointSuffix) default: return status.Errorf(codes.InvalidArgument, "%v is not a proper volume source", vs) } @@ -1300,6 +1309,77 @@ func isValidVolumeCapabilities(volCaps []*csi.VolumeCapability) error { return nil } +func (d *Driver) authorizeAzcopyWithIdentity() ([]string, error) { + azureAuthConfig := d.cloud.Config.AzureAuthConfig + var authAzcopyEnv []string + if azureAuthConfig.UseManagedIdentityExtension { + authAzcopyEnv = append(authAzcopyEnv, fmt.Sprintf("%s=%s", azcopyAutoLoginType, MSI)) + if len(azureAuthConfig.UserAssignedIdentityID) > 0 { + klog.V(2).Infof("use user assigned managed identity to authorize azcopy") + authAzcopyEnv = append(authAzcopyEnv, fmt.Sprintf("%s=%s", azcopyMSIClientID, azureAuthConfig.UserAssignedIdentityID)) + } else { + klog.V(2).Infof("use system-assigned managed identity to authorize azcopy") + } + return authAzcopyEnv, nil + } + if len(azureAuthConfig.AADClientSecret) > 0 { + klog.V(2).Infof("use service principal to authorize azcopy") + authAzcopyEnv = append(authAzcopyEnv, fmt.Sprintf("%s=%s", azcopyAutoLoginType, SPN)) + if azureAuthConfig.AADClientID == "" || azureAuthConfig.TenantID == "" { + return []string{}, fmt.Errorf("AADClientID and TenantID must be set when use service principal") + } + authAzcopyEnv = append(authAzcopyEnv, fmt.Sprintf("%s=%s", azcopySPAApplicationID, azureAuthConfig.AADClientID)) + authAzcopyEnv = append(authAzcopyEnv, fmt.Sprintf("%s=%s", azcopySPAClientSecret, azureAuthConfig.AADClientSecret)) + authAzcopyEnv = append(authAzcopyEnv, fmt.Sprintf("%s=%s", azcopyTenantID, azureAuthConfig.TenantID)) + klog.V(2).Infof(fmt.Sprintf("set AZCOPY_SPA_APPLICATION_ID=%s, AZCOPY_TENANT_ID=%s successfully", azureAuthConfig.AADClientID, azureAuthConfig.TenantID)) + + return authAzcopyEnv, nil + } + return []string{}, fmt.Errorf("neither the service principal nor the managed identity has been set") +} + +// getAzcopyAuth will only generate sas token for azcopy in following conditions: +// 1. secrets is not empty +// 2. driver is not using managed identity and service principal +// 3. azcopy returns AuthorizationPermissionMismatch error when using service principal or managed identity +// 4. parameter useSasToken is true +func (d *Driver) getAzcopyAuth(ctx context.Context, accountName, accountKey, storageEndpointSuffix string, accountOptions *azure.AccountOptions, secrets map[string]string, secretName, secretNamespace string, useSasToken bool) (string, []string, error) { + var authAzcopyEnv []string + if !useSasToken && len(secrets) == 0 && len(secretName) == 0 { + var err error + authAzcopyEnv, err = d.authorizeAzcopyWithIdentity() + if err != nil { + klog.Warningf("failed to authorize azcopy with identity, error: %v", err) + } else { + if len(authAzcopyEnv) > 0 { + // search in cache first + cache, err := d.azcopySasTokenCache.Get(accountName, azcache.CacheReadTypeDefault) + if err != nil { + return "", nil, fmt.Errorf("get(%s) from azcopySasTokenCache failed with error: %v", accountName, err) + } + if cache != nil { + klog.V(2).Infof("use sas token for account(%s) since this account is found in azcopySasTokenCache", accountName) + useSasToken = true + } + } + } + } + + if len(secrets) > 0 || len(secretName) > 0 || len(authAzcopyEnv) == 0 || useSasToken { + var err error + if accountKey == "" { + if accountKey, err = d.GetStorageAccesskey(ctx, accountOptions, secrets, secretName, secretNamespace); err != nil { + return "", nil, err + } + } + klog.V(2).Infof("generate sas token for account(%s)", accountName) + sasToken, err := generateSASToken(accountName, accountKey, storageEndpointSuffix, d.sasTokenExpirationMinutes) + return sasToken, nil, err + } + return "", authAzcopyEnv, nil +} + +// generateSASToken generate a sas token for storage account func generateSASToken(accountName, accountKey, storageEndpointSuffix string, expiryTime int) (string, error) { credential, err := service.NewSharedKeyCredential(accountName, accountKey) if err != nil { diff --git a/pkg/azurefile/controllerserver_test.go b/pkg/azurefile/controllerserver_test.go index 943e92a46b..7057910a01 100644 --- a/pkg/azurefile/controllerserver_test.go +++ b/pkg/azurefile/controllerserver_test.go @@ -29,7 +29,9 @@ import ( "time" "sigs.k8s.io/azurefile-csi-driver/pkg/util" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient" "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/fileclient" + "sigs.k8s.io/cloud-provider-azure/pkg/provider/config" "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/subnetclient/mocksubnetclient" @@ -1645,10 +1647,13 @@ func TestCopyVolume(t *testing.T) { VolumeContentSource: &volumecontensource, } + secret := map[string]string{} + d := NewFakeDriver() + ctx := context.Background() expectedErr := status.Errorf(codes.InvalidArgument, "copy volume from volumeSnapshot is not supported") - err := d.copyVolume(req, "", nil, "core.windows.net") + err := d.copyVolume(ctx, req, "", []string{}, "", "", secret, nil, nil, "core.windows.net") if !reflect.DeepEqual(err, expectedErr) { t.Errorf("Unexpected error: %v", err) } @@ -1677,10 +1682,13 @@ func TestCopyVolume(t *testing.T) { VolumeContentSource: &volumecontensource, } + secret := map[string]string{} + d := NewFakeDriver() + ctx := context.Background() expectedErr := fmt.Errorf("protocol nfs is not supported for volume cloning") - err := d.copyVolume(req, "", &fileclient.ShareOptions{Protocol: storage.EnabledProtocolsNFS}, "core.windows.net") + err := d.copyVolume(ctx, req, "", []string{}, "", "", secret, &fileclient.ShareOptions{Protocol: storage.EnabledProtocolsNFS}, nil, "core.windows.net") if !reflect.DeepEqual(err, expectedErr) { t.Errorf("Unexpected error: %v", err) } @@ -1709,10 +1717,13 @@ func TestCopyVolume(t *testing.T) { VolumeContentSource: &volumecontensource, } + secret := map[string]string{} + d := NewFakeDriver() + ctx := context.Background() expectedErr := status.Errorf(codes.NotFound, "error parsing volume id: \"unit-test\", should at least contain two #") - err := d.copyVolume(req, "", &fileclient.ShareOptions{Name: "dstFileshare"}, "core.windows.net") + err := d.copyVolume(ctx, req, "", []string{}, "", "", secret, &fileclient.ShareOptions{Name: "dstFileshare"}, nil, "core.windows.net") if !reflect.DeepEqual(err, expectedErr) { t.Errorf("Unexpected error: %v", err) } @@ -1741,10 +1752,13 @@ func TestCopyVolume(t *testing.T) { VolumeContentSource: &volumecontensource, } + secret := map[string]string{} + d := NewFakeDriver() + ctx := context.Background() expectedErr := fmt.Errorf("srcFileShareName() or dstFileShareName(dstFileshare) is empty") - err := d.copyVolume(req, "", &fileclient.ShareOptions{Name: "dstFileshare"}, "core.windows.net") + err := d.copyVolume(ctx, req, "", []string{}, "", "", secret, &fileclient.ShareOptions{Name: "dstFileshare"}, nil, "core.windows.net") if !reflect.DeepEqual(err, expectedErr) { t.Errorf("Unexpected error: %v", err) } @@ -1773,10 +1787,13 @@ func TestCopyVolume(t *testing.T) { VolumeContentSource: &volumecontensource, } + secret := map[string]string{} + d := NewFakeDriver() + ctx := context.Background() expectedErr := fmt.Errorf("srcFileShareName(fileshare) or dstFileShareName() is empty") - err := d.copyVolume(req, "", &fileclient.ShareOptions{}, "core.windows.net") + err := d.copyVolume(ctx, req, "", []string{}, "", "", secret, &fileclient.ShareOptions{}, nil, "core.windows.net") if !reflect.DeepEqual(err, expectedErr) { t.Errorf("Unexpected error: %v", err) } @@ -1805,12 +1822,15 @@ func TestCopyVolume(t *testing.T) { VolumeContentSource: &volumecontensource, } + secret := map[string]string{} + ctx := context.Background() + ctrl := gomock.NewController(t) defer ctrl.Finish() m := util.NewMockEXEC(ctrl) listStr := "JobId: ed1c3833-eaff-fe42-71d7-513fb065a9d9\nStart Time: Monday, 07-Aug-23 03:29:54 UTC\nStatus: Completed\nCommand: copy https://{accountName}.file.core.windows.net/{srcFileshare}{SAStoken} https://{accountName}.file.core.windows.net/{dstFileshare}{SAStoken} --recursive --check-length=false" - m.EXPECT().RunCommand(gomock.Eq("azcopy jobs list | grep dstFileshare -B 3")).Return(listStr, nil) + m.EXPECT().RunCommand(gomock.Eq("azcopy jobs list | grep dstFileshare -B 3"), gomock.Any()).Return(listStr, nil) // if test.enableShow { // m.EXPECT().RunCommand(gomock.Not("azcopy jobs list | grep dstContainer -B 3")).Return(test.showStr, test.showErr) // } @@ -1818,7 +1838,7 @@ func TestCopyVolume(t *testing.T) { d.azcopy.ExecCmd = m var expectedErr error - err := d.copyVolume(req, "", &fileclient.ShareOptions{Name: "dstFileshare"}, "core.windows.net") + err := d.copyVolume(ctx, req, "sastoken", []string{}, "", "", secret, &fileclient.ShareOptions{Name: "dstFileshare"}, nil, "core.windows.net") if !reflect.DeepEqual(err, expectedErr) { t.Errorf("Unexpected error: %v", err) } @@ -1846,6 +1866,8 @@ func TestCopyVolume(t *testing.T) { Parameters: mp, VolumeContentSource: &volumecontensource, } + secret := map[string]string{} + ctx := context.Background() ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -1853,15 +1875,15 @@ func TestCopyVolume(t *testing.T) { m := util.NewMockEXEC(ctrl) listStr1 := "JobId: ed1c3833-eaff-fe42-71d7-513fb065a9d9\nStart Time: Monday, 07-Aug-23 03:29:54 UTC\nStatus: InProgress\nCommand: copy https://{accountName}.file.core.windows.net/{srcFileshare}{SAStoken} https://{accountName}.file.core.windows.net/{dstFileshare}{SAStoken} --recursive --check-length=false" listStr2 := "JobId: ed1c3833-eaff-fe42-71d7-513fb065a9d9\nStart Time: Monday, 07-Aug-23 03:29:54 UTC\nStatus: Completed\nCommand: copy https://{accountName}.file.core.windows.net/{srcFileshare}{SAStoken} https://{accountName}.file.core.windows.net/{dstFileshare}{SAStoken} --recursive --check-length=false" - o1 := m.EXPECT().RunCommand(gomock.Eq("azcopy jobs list | grep dstFileshare -B 3")).Return(listStr1, nil).Times(1) - m.EXPECT().RunCommand(gomock.Not("azcopy jobs list | grep dstFileshare -B 3")).Return("Percent Complete (approx): 50.0", nil) - o2 := m.EXPECT().RunCommand(gomock.Eq("azcopy jobs list | grep dstFileshare -B 3")).Return(listStr2, nil) + o1 := m.EXPECT().RunCommand(gomock.Eq("azcopy jobs list | grep dstFileshare -B 3"), gomock.Any()).Return(listStr1, nil).Times(1) + m.EXPECT().RunCommand(gomock.Not("azcopy jobs list | grep dstFileshare -B 3"), gomock.Any()).Return("Percent Complete (approx): 50.0", nil) + o2 := m.EXPECT().RunCommand(gomock.Eq("azcopy jobs list | grep dstFileshare -B 3"), gomock.Any()).Return(listStr2, nil) gomock.InOrder(o1, o2) d.azcopy.ExecCmd = m var expectedErr error - err := d.copyVolume(req, "", &fileclient.ShareOptions{Name: "dstFileshare"}, "core.windows.net") + err := d.copyVolume(ctx, req, "sastoken", []string{}, "", "", secret, &fileclient.ShareOptions{Name: "dstFileshare"}, nil, "core.windows.net") if !reflect.DeepEqual(err, expectedErr) { t.Errorf("Unexpected error: %v", err) } @@ -2904,3 +2926,187 @@ func TestGenerateSASToken(t *testing.T) { }) } } + +func TestAuthorizeAzcopyWithIdentity(t *testing.T) { + testCases := []struct { + name string + testFunc func(t *testing.T) + }{ + { + name: "use service principal to authorize azcopy", + testFunc: func(t *testing.T) { + d := NewFakeDriver() + d.cloud = &azure.Cloud{ + Config: azure.Config{ + AzureAuthConfig: config.AzureAuthConfig{ + ARMClientConfig: azclient.ARMClientConfig{ + TenantID: "TenantID", + }, + AzureAuthConfig: azclient.AzureAuthConfig{ + AADClientID: "AADClientID", + AADClientSecret: "AADClientSecret", + }, + }, + }, + } + expectedAuthAzcopyEnv := []string{ + fmt.Sprintf(azcopyAutoLoginType + "=SPN"), + fmt.Sprintf(azcopySPAApplicationID + "=AADClientID"), + fmt.Sprintf(azcopySPAClientSecret + "=AADClientSecret"), + fmt.Sprintf(azcopyTenantID + "=TenantID"), + } + var expectedErr error + authAzcopyEnv, err := d.authorizeAzcopyWithIdentity() + if !reflect.DeepEqual(authAzcopyEnv, expectedAuthAzcopyEnv) || !reflect.DeepEqual(err, expectedErr) { + t.Errorf("Unexpected authAzcopyEnv: %v, Unexpected error: %v", authAzcopyEnv, err) + } + }, + }, + { + name: "use service principal to authorize azcopy but client id is empty", + testFunc: func(t *testing.T) { + d := NewFakeDriver() + d.cloud = &azure.Cloud{ + Config: azure.Config{ + AzureAuthConfig: config.AzureAuthConfig{ + ARMClientConfig: azclient.ARMClientConfig{ + TenantID: "TenantID", + }, + AzureAuthConfig: azclient.AzureAuthConfig{ + AADClientSecret: "AADClientSecret", + }, + }, + }, + } + expectedAuthAzcopyEnv := []string{} + expectedErr := fmt.Errorf("AADClientID and TenantID must be set when use service principal") + authAzcopyEnv, err := d.authorizeAzcopyWithIdentity() + if !reflect.DeepEqual(authAzcopyEnv, expectedAuthAzcopyEnv) || !reflect.DeepEqual(err, expectedErr) { + t.Errorf("Unexpected authAzcopyEnv: %v, Unexpected error: %v", authAzcopyEnv, err) + } + }, + }, + { + name: "use user assigned managed identity to authorize azcopy", + testFunc: func(t *testing.T) { + d := NewFakeDriver() + d.cloud = &azure.Cloud{ + Config: azure.Config{ + AzureAuthConfig: config.AzureAuthConfig{ + AzureAuthConfig: azclient.AzureAuthConfig{ + UseManagedIdentityExtension: true, + UserAssignedIdentityID: "UserAssignedIdentityID", + }, + }, + }, + } + expectedAuthAzcopyEnv := []string{ + fmt.Sprintf(azcopyAutoLoginType + "=MSI"), + fmt.Sprintf(azcopyMSIClientID + "=UserAssignedIdentityID"), + } + var expectedErr error + authAzcopyEnv, err := d.authorizeAzcopyWithIdentity() + if !reflect.DeepEqual(authAzcopyEnv, expectedAuthAzcopyEnv) || !reflect.DeepEqual(err, expectedErr) { + t.Errorf("Unexpected authAzcopyEnv: %v, Unexpected error: %v", authAzcopyEnv, err) + } + }, + }, + { + name: "use system assigned managed identity to authorize azcopy", + testFunc: func(t *testing.T) { + d := NewFakeDriver() + d.cloud = &azure.Cloud{ + Config: azure.Config{ + AzureAuthConfig: config.AzureAuthConfig{ + AzureAuthConfig: azclient.AzureAuthConfig{ + UseManagedIdentityExtension: true, + }, + }, + }, + } + expectedAuthAzcopyEnv := []string{ + fmt.Sprintf(azcopyAutoLoginType + "=MSI"), + } + var expectedErr error + authAzcopyEnv, err := d.authorizeAzcopyWithIdentity() + if !reflect.DeepEqual(authAzcopyEnv, expectedAuthAzcopyEnv) || !reflect.DeepEqual(err, expectedErr) { + t.Errorf("Unexpected authAzcopyEnv: %v, Unexpected error: %v", authAzcopyEnv, err) + } + }, + }, + { + name: "AADClientSecret be nil and useManagedIdentityExtension is false", + testFunc: func(t *testing.T) { + d := NewFakeDriver() + d.cloud = &azure.Cloud{ + Config: azure.Config{ + AzureAuthConfig: config.AzureAuthConfig{}, + }, + } + expectedAuthAzcopyEnv := []string{} + expectedErr := fmt.Errorf("neither the service principal nor the managed identity has been set") + authAzcopyEnv, err := d.authorizeAzcopyWithIdentity() + if !reflect.DeepEqual(authAzcopyEnv, expectedAuthAzcopyEnv) || !reflect.DeepEqual(err, expectedErr) { + t.Errorf("Unexpected authAzcopyEnv: %v, Unexpected error: %v", authAzcopyEnv, err) + } + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, tc.testFunc) + } +} + +func TestGetAzcopyAuth(t *testing.T) { + testCases := []struct { + name string + testFunc func(t *testing.T) + }{ + { + name: "failed to get accountKey in secrets", + testFunc: func(t *testing.T) { + d := NewFakeDriver() + d.cloud = &azure.Cloud{ + Config: azure.Config{}, + } + secrets := map[string]string{ + defaultSecretAccountName: "accountName", + } + + ctx := context.Background() + expectedAccountSASToken := "" + expectedErr := fmt.Errorf("could not find accountkey or azurestorageaccountkey field in secrets") + accountSASToken, authAzcopyEnv, err := d.getAzcopyAuth(ctx, "accountName", "", "core.windows.net", &azure.AccountOptions{}, secrets, "secretsName", "secretsNamespace", false) + if !reflect.DeepEqual(err, expectedErr) || authAzcopyEnv != nil || !reflect.DeepEqual(accountSASToken, expectedAccountSASToken) { + t.Errorf("Unexpected accountSASToken: %s, Unexpected authAzcopyEnv: %v, Unexpected error: %v", accountSASToken, authAzcopyEnv, err) + } + }, + }, + { + name: "generate SAS token failed for illegal account key", + testFunc: func(t *testing.T) { + d := NewFakeDriver() + d.cloud = &azure.Cloud{ + Config: azure.Config{}, + } + secrets := map[string]string{ + defaultSecretAccountName: "accountName", + defaultSecretAccountKey: "fakeValue", + } + + ctx := context.Background() + expectedAccountSASToken := "" + expectedErr := status.Errorf(codes.Internal, fmt.Sprintf("failed to generate sas token in creating new shared key credential, accountName: %s, err: %s", "accountName", "decode account key: illegal base64 data at input byte 8")) + accountSASToken, authAzcopyEnv, err := d.getAzcopyAuth(ctx, "accountName", "", "core.windows.net", &azure.AccountOptions{}, secrets, "secretsName", "secretsNamespace", false) + if !reflect.DeepEqual(err, expectedErr) || authAzcopyEnv != nil || !reflect.DeepEqual(accountSASToken, expectedAccountSASToken) { + t.Errorf("Unexpected accountSASToken: %s, Unexpected authAzcopyEnv: %v, Unexpected error: %v", accountSASToken, authAzcopyEnv, err) + } + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, tc.testFunc) + } +} diff --git a/pkg/util/util.go b/pkg/util/util.go index aaa4529779..fd8f29bad9 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -89,14 +89,18 @@ func RunPowershellCmd(command string, envs ...string) ([]byte, error) { } type EXEC interface { - RunCommand(string) (string, error) + RunCommand(string, []string) (string, error) } type ExecCommand struct { } -func (ec *ExecCommand) RunCommand(cmd string) (string, error) { - out, err := exec.Command("sh", "-c", cmd).CombinedOutput() +func (ec *ExecCommand) RunCommand(cmdStr string, authEnv []string) (string, error) { + cmd := exec.Command("sh", "-c", cmdStr) + if len(authEnv) > 0 { + cmd.Env = append(os.Environ(), authEnv...) + } + out, err := cmd.CombinedOutput() return string(out), err } @@ -105,7 +109,7 @@ type Azcopy struct { } // GetAzcopyJob get the azcopy job status if job existed -func (ac *Azcopy) GetAzcopyJob(dstFileshare string) (AzcopyJobState, string, error) { +func (ac *Azcopy) GetAzcopyJob(dstFileshare string, authAzcopyEnv []string) (AzcopyJobState, string, error) { cmdStr := fmt.Sprintf("azcopy jobs list | grep %s -B 3", dstFileshare) // cmd output example: // JobId: ed1c3833-eaff-fe42-71d7-513fb065a9d9 @@ -120,7 +124,7 @@ func (ac *Azcopy) GetAzcopyJob(dstFileshare string) (AzcopyJobState, string, err if ac.ExecCmd == nil { ac.ExecCmd = &ExecCommand{} } - out, err := ac.ExecCmd.RunCommand(cmdStr) + out, err := ac.ExecCmd.RunCommand(cmdStr, authAzcopyEnv) // if grep command returns nothing, the exec will return exit status 1 error, so filter this error if err != nil && err.Error() != "exit status 1" { klog.Warningf("failed to get azcopy job with error: %v, jobState: %v", err, AzcopyJobError) @@ -140,7 +144,7 @@ func (ac *Azcopy) GetAzcopyJob(dstFileshare string) (AzcopyJobState, string, err cmdPercentStr := fmt.Sprintf("azcopy jobs show %s | grep Percent", jobid) // cmd out example: // Percent Complete (approx): 100.0 - summary, err := ac.ExecCmd.RunCommand(cmdPercentStr) + summary, err := ac.ExecCmd.RunCommand(cmdPercentStr, authAzcopyEnv) if err != nil { klog.Warningf("failed to get azcopy job with error: %v, jobState: %v", err, AzcopyJobError) return AzcopyJobError, "", fmt.Errorf("couldn't show jobs summary in azcopy %v", err) @@ -153,6 +157,15 @@ func (ac *Azcopy) GetAzcopyJob(dstFileshare string) (AzcopyJobState, string, err return jobState, percent, nil } +// TestListJobs test azcopy jobs list command with authAzcopyEnv +func (ac *Azcopy) TestListJobs(accountName, storageEndpointSuffix string, authAzcopyEnv []string) (string, error) { + cmdStr := fmt.Sprintf("azcopy list %s", fmt.Sprintf("https://%s.file.%s", accountName, storageEndpointSuffix)) + if ac.ExecCmd == nil { + ac.ExecCmd = &ExecCommand{} + } + return ac.ExecCmd.RunCommand(cmdStr, authAzcopyEnv) +} + // parseAzcopyJobList parse command azcopy jobs list, get jobid and state from joblist func parseAzcopyJobList(joblist string) (string, AzcopyJobState, error) { jobid := "" diff --git a/pkg/util/util_mock.go b/pkg/util/util_mock.go index f381ec968d..6764d3fcf7 100644 --- a/pkg/util/util_mock.go +++ b/pkg/util/util_mock.go @@ -51,16 +51,16 @@ func (m *MockEXEC) EXPECT() *MockEXECMockRecorder { } // RunCommand mocks base method. -func (m *MockEXEC) RunCommand(arg0 string) (string, error) { +func (m *MockEXEC) RunCommand(arg0 string, arg1 []string) (string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RunCommand", arg0) + ret := m.ctrl.Call(m, "RunCommand", arg0, arg1) ret0, _ := ret[0].(string) ret1, _ := ret[1].(error) return ret0, ret1 } // RunCommand indicates an expected call of RunCommand. -func (mr *MockEXECMockRecorder) RunCommand(arg0 interface{}) *gomock.Call { +func (mr *MockEXECMockRecorder) RunCommand(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RunCommand", reflect.TypeOf((*MockEXEC)(nil).RunCommand), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RunCommand", reflect.TypeOf((*MockEXEC)(nil).RunCommand), arg0, arg1) } diff --git a/pkg/util/util_test.go b/pkg/util/util_test.go index 5d8af52686..aee57d9535 100644 --- a/pkg/util/util_test.go +++ b/pkg/util/util_test.go @@ -155,14 +155,14 @@ func TestGetAzcopyJob(t *testing.T) { defer ctrl.Finish() m := NewMockEXEC(ctrl) - m.EXPECT().RunCommand(gomock.Eq("azcopy jobs list | grep dstFileshare -B 3")).Return(test.listStr, test.listErr) + m.EXPECT().RunCommand(gomock.Eq("azcopy jobs list | grep dstFileshare -B 3"), []string{}).Return(test.listStr, test.listErr) if test.enableShow { - m.EXPECT().RunCommand(gomock.Not("azcopy jobs list | grep dstFileshare -B 3")).Return(test.showStr, test.showErr) + m.EXPECT().RunCommand(gomock.Not("azcopy jobs list | grep dstFileshare -B 3"), []string{}).Return(test.showStr, test.showErr) } azcopyFunc := &Azcopy{} azcopyFunc.ExecCmd = m - jobState, percent, err := azcopyFunc.GetAzcopyJob(dstFileshare) + jobState, percent, err := azcopyFunc.GetAzcopyJob(dstFileshare, []string{}) if jobState != test.expectedJobState || percent != test.expectedPercent || !reflect.DeepEqual(err, test.expectedErr) { t.Errorf("test[%s]: unexpected jobState: %v, percent: %v, err: %v, expected jobState: %v, percent: %v, err: %v", test.desc, jobState, percent, err, test.expectedJobState, test.expectedPercent, test.expectedErr) }