Skip to content

Commit

Permalink
use cluster identity for azcopy
Browse files Browse the repository at this point in the history
  • Loading branch information
umagnus committed Dec 15, 2023
1 parent ded1639 commit a45cd1d
Show file tree
Hide file tree
Showing 2 changed files with 186 additions and 4 deletions.
72 changes: 68 additions & 4 deletions pkg/blob/controllerserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"fmt"
"net/url"
"os"
"os/exec"
"strconv"
"strings"
Expand Down Expand Up @@ -50,6 +51,12 @@ import (
const (
privateEndpoint = "privateendpoint"

azcopyAutoLoginType = "AZCOPY_AUTO_LOGIN_TYPE"
azcopySPAApplicationID = "AZCOPY_SPA_APPLICATION_ID"
azcopySPAClientSecret = "AZCOPY_SPA_CLIENT_SECRET"
azcopyTenantID = "AZCOPY_TENANT_ID"
azcopyMSIClientID = "AZCOPY_MSI_CLIENT_ID"

waitForCopyInterval = 5 * time.Second
waitForCopyTimeout = 3 * time.Minute
)
Expand Down Expand Up @@ -726,10 +733,27 @@ func (d *Driver) copyBlobContainer(_ context.Context, req *csi.CreateVolumeReque
return fmt.Errorf("srcContainerName(%s) or dstContainerName(%s) is empty", srcContainerName, dstContainerName)
}

klog.V(2).Infof("generate sas token for account(%s)", accountName)
accountSasToken, genErr := generateSASToken(accountName, accountKey, storageEndpointSuffix, d.sasTokenExpirationMinutes)
if genErr != nil {
return genErr
var accountSasToken string
secrets := req.GetSecrets()
if len(secrets) != 0 {
if _, ok := secrets[accountSasTokenField]; ok {
accountSasToken = secrets[accountSasTokenField]
klog.V(2).Infof("get sas token from secret for copy volume successfully")
}
}
if accountSasToken == "" {
if len(d.cloud.Config.AzureAuthConfig.AADClientSecret) > 0 || d.cloud.Config.AzureAuthConfig.UseManagedIdentityExtension {
err = d.authorizeAzcopyBySecurityPrincipal()
if err != nil {
return err
}
} else {
klog.V(2).Infof("generate sas token for account(%s)", accountName)
accountSasToken, err = generateSASToken(accountName, accountKey, storageEndpointSuffix, d.sasTokenExpirationMinutes)
if err != nil {
return err
}
}
}

timeAfter := time.After(waitForCopyTimeout)
Expand Down Expand Up @@ -780,6 +804,46 @@ func (d *Driver) copyVolume(ctx context.Context, req *csi.CreateVolumeRequest, a
}
}

func (d *Driver) authorizeAzcopyBySecurityPrincipal() error {
if len(d.cloud.Config.AzureAuthConfig.AADClientSecret) > 0 {
klog.V(2).Infof("use service principal to authorize azcopy")
if err := os.Setenv(azcopyAutoLoginType, "SPN"); err != nil {
return err
}
if d.cloud.Config.AzureAuthConfig.AADClientID == "" {
return fmt.Errorf("AADClientID and AADClientSecret must be set when use service principal")
}
if err := os.Setenv(azcopySPAApplicationID, d.cloud.Config.AzureAuthConfig.AADClientID); err != nil {
return err
}
if err := os.Setenv(azcopySPAClientSecret, d.cloud.Config.AzureAuthConfig.AADClientSecret); err != nil {
return err
}
if d.cloud.Config.AzureAuthConfig.TenantID != "" {
if err := os.Setenv(azcopyTenantID, d.cloud.Config.AzureAuthConfig.TenantID); err != nil {
return err
}
klog.V(2).Infof(fmt.Sprintf("set AZCOPY_TENANT_ID=%s successfully", d.cloud.Config.AzureAuthConfig.TenantID))
}
return nil
}
if d.cloud.Config.AzureAuthConfig.UseManagedIdentityExtension {
if err := os.Setenv(azcopyAutoLoginType, "MSI"); err != nil {
return err
}
if len(d.cloud.Config.AzureAuthConfig.UserAssignedIdentityID) > 0 {
klog.V(2).Infof("use user assigned managed identity to authorize azcopy")
if err := os.Setenv(azcopyMSIClientID, d.cloud.Config.AzureAuthConfig.UserAssignedIdentityID); err != nil {

Check warning on line 836 in pkg/blob/controllerserver.go

View workflow job for this annotation

GitHub Actions / Go Lint

if-return: redundant if ...; err != nil check, just return error instead. (revive)
return err
}
return nil
}
klog.V(2).Infof("use system-assigned managed identity to authorize azcopy")
return nil
}
return fmt.Errorf("AADClientSecret shouldn't be nil or useManagedIdentityExtension must be set to true")
}

// isValidVolumeCapabilities validates the given VolumeCapability array is valid
func isValidVolumeCapabilities(volCaps []*csi.VolumeCapability) error {
if len(volCaps) == 0 {
Expand Down
118 changes: 118 additions & 0 deletions pkg/blob/controllerserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@ import (
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/utils/pointer"
"sigs.k8s.io/blob-csi-driver/pkg/util"
"sigs.k8s.io/cloud-provider-azure/pkg/azclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/blobclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/storageaccountclient/mockstorageaccountclient"
azure "sigs.k8s.io/cloud-provider-azure/pkg/provider"
"sigs.k8s.io/cloud-provider-azure/pkg/provider/config"
"sigs.k8s.io/cloud-provider-azure/pkg/retry"
)

Expand Down Expand Up @@ -1780,3 +1782,119 @@ func Test_generateSASToken(t *testing.T) {
})
}
}

func Test_authorizeAzcopyBySecurityPrincipal(t *testing.T) {
testCases := []struct {
name string
testFunc func(t *testing.T)
}{
{
name: "use service principal to authorize azcopy",
testFunc: func(t *testing.T) {
d := NewFakeDriver()
d.cloud = &azure.Cloud{
Config: azure.Config{
AzureAuthConfig: config.AzureAuthConfig{
ARMClientConfig: azclient.ARMClientConfig{
TenantID: "TenantID",
},
AzureAuthConfig: azclient.AzureAuthConfig{
AADClientID: "AADClientID",
AADClientSecret: "AADClientSecret",
},
},
},
}
var expectedErr error
err := d.authorizeAzcopyBySecurityPrincipal()
if !reflect.DeepEqual(err, expectedErr) {
t.Errorf("Unexpected error: %v", err)
}
},
},
{
name: "use service principal to authorize azcopy but client id is empty",
testFunc: func(t *testing.T) {
d := NewFakeDriver()
d.cloud = &azure.Cloud{
Config: azure.Config{
AzureAuthConfig: config.AzureAuthConfig{
ARMClientConfig: azclient.ARMClientConfig{
TenantID: "TenantID",
},
AzureAuthConfig: azclient.AzureAuthConfig{
AADClientSecret: "AADClientSecret",
},
},
},
}
expectedErr := fmt.Errorf("AADClientID and AADClientSecret must be set when use service principal")
err := d.authorizeAzcopyBySecurityPrincipal()
if !reflect.DeepEqual(err, expectedErr) {
t.Errorf("Unexpected error: %v", err)
}
},
},
{
name: "use user assigned managed identity to authorize azcopy",
testFunc: func(t *testing.T) {
d := NewFakeDriver()
d.cloud = &azure.Cloud{
Config: azure.Config{
AzureAuthConfig: config.AzureAuthConfig{
AzureAuthConfig: azclient.AzureAuthConfig{
UseManagedIdentityExtension: true,
UserAssignedIdentityID: "UserAssignedIdentityID",
},
},
},
}
var expectedErr error
err := d.authorizeAzcopyBySecurityPrincipal()
if !reflect.DeepEqual(err, expectedErr) {
t.Errorf("Unexpected error: %v", err)
}
},
},
{
name: "use system assigned managed identity to authorize azcopy",
testFunc: func(t *testing.T) {
d := NewFakeDriver()
d.cloud = &azure.Cloud{
Config: azure.Config{
AzureAuthConfig: config.AzureAuthConfig{
AzureAuthConfig: azclient.AzureAuthConfig{
UseManagedIdentityExtension: true,
},
},
},
}
var expectedErr error
err := d.authorizeAzcopyBySecurityPrincipal()
if !reflect.DeepEqual(err, expectedErr) {
t.Errorf("Unexpected error: %v", err)
}
},
},
{
name: "AADClientSecret be nil and useManagedIdentityExtension is false",
testFunc: func(t *testing.T) {
d := NewFakeDriver()
d.cloud = &azure.Cloud{
Config: azure.Config{
AzureAuthConfig: config.AzureAuthConfig{},
},
}
expectedErr := fmt.Errorf("AADClientSecret shouldn't be nil or useManagedIdentityExtension must be set to true")
err := d.authorizeAzcopyBySecurityPrincipal()
if !reflect.DeepEqual(err, expectedErr) {
t.Errorf("Unexpected error: %v", err)
}
},
},
}

for _, tc := range testCases {
t.Run(tc.name, tc.testFunc)
}
}

0 comments on commit a45cd1d

Please sign in to comment.