Skip to content

Commit

Permalink
[Service Connector] az spring-cloud connection create postgres: Add…
Browse files Browse the repository at this point in the history
… `--system-identity` for springcloud-postgres connection (#22459)

* connect by mi for postgresql and spring cloud

* connect by mi for postgresql and spring cloud

* update

* update

* update default auth type

* update

* set aad user

* fix lint

* update

* update error msg

refine output message

update

enable localcontext

update setup

update client type and output message

add dependencies alphabetically

update setup

update warning

add auto commit

update username to lower case

add new permissions

update command

switch to az spring

add dependencies

Revert "add dependencies"

This reverts commit 6ddbf9483a52e3c4169246033d0b8950bbf565e9.

update psycopg2

install dependency runtime

update to spring command

remove linux pg dependency

* support webapp

* update style

* support container app

* support postgres flex server

* refactor module

* refactor credential free

* lint and add sql

* fix name and improve psql flex

* support sql

* lint

* remove pyodbc dependency

* update psql query

* support mysql

* update sql suffix

* fix error message

* update postgresql connection

* update

* fix

* add retry for get identity api

* update

* update connection error message

* support mysql identity

* fix some bug

* remove client type restriction

* update sql authentication

* update

* lint

* update sql aad auth to azcli credential

* update psql query

* remove connection string log

* update package

* rename mysql-identity-id

* revert test

* lint

* solve comment

* update help

* fix errorkey

* pg_flex: enable uuid-ossp extension

* fix mysql grant query
  • Loading branch information
xfz11 authored Sep 19, 2022
1 parent 48bdcf1 commit 19b0ebe
Show file tree
Hide file tree
Showing 12 changed files with 1,030 additions and 57 deletions.

Large diffs are not rendered by default.

18 changes: 16 additions & 2 deletions src/azure-cli/azure/cli/command_modules/serviceconnector/_help.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,8 @@ def get_source_display_name(sourcename):
short-summary: The secret auth info
long-summary: |
Usage: --secret name=XX secret=XX
--secret name=XX secret-uri=XX
--secret name=XX secret-name=XX
name : Required. Username or account name for secret auth.
secret : One of <secret, secret-uri, secret-name> is required. Password or account key for secret auth.
Expand All @@ -235,13 +237,25 @@ def get_source_display_name(sourcename):
Usage: --secret
''' if AUTH_TYPE.SecretAuto in auth_types else ''
system_identity_param = '''
system_identity_param = ''
if AUTH_TYPE.SystemIdentity in auth_types:
if target in {RESOURCE.MysqlFlexible}:
system_identity_param = '''
- name: --system-identity
short-summary: The system assigned identity auth info
long-summary: |
Usage: --system-identity mysql-identity-id=xx
mysql-identity-id : Optional. ID of identity used for MySQL flexible server AAD Authentication. Ignore it if you are the server AAD administrator.
'''
else:
system_identity_param = '''
- name: --system-identity
short-summary: The system assigned identity auth info
long-summary: |
Usage: --system-identity
''' if AUTH_TYPE.SystemIdentity in auth_types else ''
'''
user_identity_param = '''
- name: --user-identity
short-summary: The user assigned identity auth info
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -651,11 +651,11 @@ class CLIENT_TYPE(Enum):
# The first one will be used as the default auth type
SUPPORTED_AUTH_TYPE = {
RESOURCE.WebApp: {
RESOURCE.Postgres: [AUTH_TYPE.Secret],
RESOURCE.PostgresFlexible: [AUTH_TYPE.Secret],
RESOURCE.Postgres: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity],
RESOURCE.PostgresFlexible: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity],
RESOURCE.Mysql: [AUTH_TYPE.Secret],
RESOURCE.MysqlFlexible: [AUTH_TYPE.Secret],
RESOURCE.Sql: [AUTH_TYPE.Secret],
RESOURCE.MysqlFlexible: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity],
RESOURCE.Sql: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity],
RESOURCE.Redis: [AUTH_TYPE.SecretAuto],
RESOURCE.RedisEnterprise: [AUTH_TYPE.SecretAuto],

Expand All @@ -679,11 +679,11 @@ class CLIENT_TYPE(Enum):
RESOURCE.ConfluentKafka: [AUTH_TYPE.Secret],
},
RESOURCE.SpringCloud: {
RESOURCE.Postgres: [AUTH_TYPE.Secret],
RESOURCE.PostgresFlexible: [AUTH_TYPE.Secret],
RESOURCE.Postgres: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity],
RESOURCE.PostgresFlexible: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity],
RESOURCE.Mysql: [AUTH_TYPE.Secret],
RESOURCE.MysqlFlexible: [AUTH_TYPE.Secret],
RESOURCE.Sql: [AUTH_TYPE.Secret],
RESOURCE.MysqlFlexible: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity],
RESOURCE.Sql: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity],
RESOURCE.Redis: [AUTH_TYPE.SecretAuto],
RESOURCE.RedisEnterprise: [AUTH_TYPE.SecretAuto],

Expand Down Expand Up @@ -735,11 +735,11 @@ class CLIENT_TYPE(Enum):
RESOURCE.ConfluentKafka: [AUTH_TYPE.Secret],
},
RESOURCE.ContainerApp: {
RESOURCE.Postgres: [AUTH_TYPE.Secret],
RESOURCE.PostgresFlexible: [AUTH_TYPE.Secret],
RESOURCE.Postgres: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity],
RESOURCE.PostgresFlexible: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity],
RESOURCE.Mysql: [AUTH_TYPE.Secret],
RESOURCE.MysqlFlexible: [AUTH_TYPE.Secret],
RESOURCE.Sql: [AUTH_TYPE.Secret],
RESOURCE.MysqlFlexible: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity],
RESOURCE.Sql: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity],
RESOURCE.Redis: [AUTH_TYPE.SecretAuto],
RESOURCE.RedisEnterprise: [AUTH_TYPE.SecretAuto],

Expand Down
89 changes: 54 additions & 35 deletions src/azure-cli/azure/cli/command_modules/serviceconnector/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,24 @@
# --------------------------------------------------------------------------------------------

import time
from knack.log import get_logger
from knack.util import todict
from msrestazure.tools import parse_resource_id
from azure.cli.core.azclierror import (
ValidationError,
CLIInternalError
)
from azure.cli.core._profile import Profile
from ._resource_config import (
SOURCE_RESOURCES_USERTOKEN,
TARGET_RESOURCES_USERTOKEN
TARGET_RESOURCES_USERTOKEN,
RESOURCE
)


logger = get_logger(__name__)


def should_load_source(source):
'''Check whether to load `az {source} connection`
If {source} is an extension (e.g, spring-cloud), load the command group only when {source} is installed
Expand Down Expand Up @@ -44,7 +50,8 @@ def generate_random_string(length=5, prefix='', lower_only=False, ensure_complex
import string

if lower_only and ensure_complexity:
raise CLIInternalError('lower_only and ensure_complexity can not both be specified to True')
raise CLIInternalError(
'lower_only and ensure_complexity can not both be specified to True')
if ensure_complexity and length < 8:
raise CLIInternalError('ensure_complexity needs length >= 8')

Expand All @@ -53,7 +60,8 @@ def generate_random_string(length=5, prefix='', lower_only=False, ensure_complex
character_set = string.ascii_lowercase

while True:
randstr = '{}{}'.format(prefix, ''.join(random.sample(character_set, length)))
randstr = '{}{}'.format(prefix, ''.join(
random.sample(character_set, length)))
lowers = [c for c in randstr if c.islower()]
uppers = [c for c in randstr if c.isupper()]
numbers = [c for c in randstr if c.isnumeric()]
Expand All @@ -63,36 +71,41 @@ def generate_random_string(length=5, prefix='', lower_only=False, ensure_complex
return randstr


def run_cli_cmd(cmd, retry=0):
def run_cli_cmd(cmd, retry=0, interval=0, should_retry_func=None):
'''Run a CLI command
:param cmd: The CLI command to be executed
:param retry: The times to re-try
:param interval: The seconds wait before retry
'''
import json
import subprocess

output = subprocess.run(cmd, shell=True, check=False, stderr=subprocess.PIPE, stdout=subprocess.PIPE)
if output.returncode != 0:
output = subprocess.run(cmd, shell=True, check=False,
stderr=subprocess.PIPE, stdout=subprocess.PIPE)
logger.debug(output)
if output.returncode != 0 or (should_retry_func and should_retry_func(output)):
if retry:
run_cli_cmd(cmd, retry - 1)
else:
raise CLIInternalError('Command execution failed, command is: '
'{}, error message is: {}'.format(cmd, output.stderr))

return json.loads(output.stdout) if output.stdout else None
time.sleep(interval)
return run_cli_cmd(cmd, retry - 1, interval)
raise CLIInternalError('Command execution failed, command is: '
'{}, error message is: {}'.format(cmd, output.stderr))
try:
return json.loads(output.stdout) if output.stdout else None
except ValueError:
return output.stdout or None


def set_user_token_header(client, cli_ctx):
'''Set user token header to work around OBO'''
from azure.cli.core._profile import Profile

# pylint: disable=protected-access
# HACK: set custom header to work around OBO
profile = Profile(cli_ctx=cli_ctx)
creds, _, _ = profile.get_raw_token()
client._client._config.headers_policy._headers['x-ms-serviceconnector-user-token'] = creds[1]
# HACK: hide token header
client._config.logging_policy.headers_to_redact.append('x-ms-serviceconnector-user-token')
client._config.logging_policy.headers_to_redact.append(
'x-ms-serviceconnector-user-token')

return client

Expand All @@ -109,16 +122,14 @@ def provider_is_registered(subscription=None):
subs_arg = ''
if subscription:
subs_arg = '--subscription {}'.format(subscription)
output = run_cli_cmd('az provider show -n Microsoft.ServiceLinker {}'.format(subs_arg))
output = run_cli_cmd(
'az provider show -n Microsoft.ServiceLinker {}'.format(subs_arg))
if output.get('registrationState') == 'NotRegistered':
return False
return True


def register_provider(subscription=None):
from knack.log import get_logger
logger = get_logger(__name__)

logger.warning('Provider Microsoft.ServiceLinker is not registered, '
'trying to register. This usually takes 1-2 minutes.')

Expand All @@ -127,7 +138,8 @@ def register_provider(subscription=None):
subs_arg = '--subscription {}'.format(subscription)

# register the provider
run_cli_cmd('az provider register -n Microsoft.ServiceLinker {}'.format(subs_arg))
run_cli_cmd(
'az provider register -n Microsoft.ServiceLinker {}'.format(subs_arg))

# verify the registration, 30 * 10s polling the result
MAX_RETRY_TIMES = 30
Expand All @@ -136,7 +148,8 @@ def register_provider(subscription=None):
count = 0
while count < MAX_RETRY_TIMES:
time.sleep(RETRY_INTERVAL)
output = run_cli_cmd('az provider show -n Microsoft.ServiceLinker {}'.format(subs_arg))
output = run_cli_cmd(
'az provider show -n Microsoft.ServiceLinker {}'.format(subs_arg))
current_state = output.get('registrationState')
if current_state == 'Registered':
return True
Expand Down Expand Up @@ -171,7 +184,8 @@ def auto_register(func, *args, **kwargs):
# target subscription is not registered, raw check
if ex.error and ex.error.code == 'UnauthorizedResourceAccess' and 'not registered' in ex.error.message:
if 'parameters' in kwargs_backup and 'target_id' in kwargs_backup.get('parameters'):
segments = parse_resource_id(kwargs_backup.get('parameters').get('target_id'))
segments = parse_resource_id(
kwargs_backup.get('parameters').get('target_id'))
target_subs = segments.get('subscription')
# double check whether target subscription is registered
if not provider_is_registered(target_subs):
Expand All @@ -185,8 +199,6 @@ def auto_register(func, *args, **kwargs):

def create_key_vault_reference_connection_if_not_exist(cmd, client, source_id, key_vault_id):
from ._validators import get_source_resource_name
from knack.log import get_logger
logger = get_logger(__name__)

logger.warning('get valid key vualt reference connection')
key_vault_connections = []
Expand All @@ -196,15 +208,15 @@ def create_key_vault_reference_connection_if_not_exist(cmd, client, source_id, k
key_vault_connections.append(connection)

source_name = get_source_resource_name(cmd)
auth_info = get_auth_if_no_valid_key_vault_connection(logger, source_name, source_id, key_vault_connections)
auth_info = get_auth_if_no_valid_key_vault_connection(
source_name, source_id, key_vault_connections)
if not auth_info:
return

# No Valid Key Vault Connection, Create
logger.warning('no valid key vault connection found. Creating...')

from ._resource_config import (
RESOURCE,
CLIENT_TYPE
)

Expand All @@ -215,7 +227,8 @@ def create_key_vault_reference_connection_if_not_exist(cmd, client, source_id, k
"id": key_vault_id
},
'auth_info': auth_info,
'client_type': CLIENT_TYPE.Dotnet, # Key Vault Configuration are same across all client types
# Key Vault Configuration are same across all client types
'client_type': CLIENT_TYPE.Dotnet,
}

if source_name == RESOURCE.KubernetesCluster:
Expand All @@ -230,22 +243,23 @@ def create_key_vault_reference_connection_if_not_exist(cmd, client, source_id, k
parameters=parameters)


def get_auth_if_no_valid_key_vault_connection(logger, source_name, source_id, key_vault_connections):
def get_auth_if_no_valid_key_vault_connection(source_name, source_id, key_vault_connections):
auth_type = 'systemAssignedIdentity'
client_id = None
subscription_id = None

if key_vault_connections:
from ._resource_config import RESOURCE
from msrestazure.tools import (
is_valid_resource_id
)

# https://docs.microsoft.com/azure/app-service/app-service-key-vault-references
if source_name == RESOURCE.WebApp:
try:
webapp = run_cli_cmd('az rest -u {}?api-version=2020-09-01 -o json'.format(source_id))
reference_identity = webapp.get('properties').get('keyVaultReferenceIdentity')
webapp = run_cli_cmd(
'az rest -u {}?api-version=2020-09-01 -o json'.format(source_id))
reference_identity = webapp.get(
'properties').get('keyVaultReferenceIdentity')
except Exception as e:
raise ValidationError('{}. Unable to get "properties.keyVaultReferenceIdentity" from {}.'
'Please check your source id is correct.'.format(e, source_id))
Expand All @@ -255,11 +269,13 @@ def get_auth_if_no_valid_key_vault_connection(logger, source_name, source_id, ke
segments = parse_resource_id(reference_identity)
subscription_id = segments.get('subscription')
try:
identity = webapp.get('identity').get('userAssignedIdentities').get(reference_identity)
identity = webapp.get('identity').get(
'userAssignedIdentities').get(reference_identity)
client_id = identity.get('clientId')
except Exception: # pylint: disable=broad-except
try:
identity = run_cli_cmd('az identity show --ids {} -o json'.format(reference_identity))
identity = run_cli_cmd(
'az identity show --ids {} -o json'.format(reference_identity))
client_id = identity.get('clientId')
except Exception: # pylint: disable=broad-except
pass
Expand All @@ -269,12 +285,14 @@ def get_auth_if_no_valid_key_vault_connection(logger, source_name, source_id, ke
for connection in key_vault_connections:
auth_info = connection.get('authInfo')
if auth_info.get('clientId') == client_id and auth_info.get('subscriptionId') == subscription_id:
logger.warning('key vualt reference connection: %s', connection.get('id'))
logger.warning(
'key vualt reference connection: %s', connection.get('id'))
return
else: # System Identity
for connection in key_vault_connections:
if connection.get('authInfo').get('authType') == auth_type:
logger.warning('key vualt reference connection: %s', connection.get('id'))
logger.warning(
'key vualt reference connection: %s', connection.get('id'))
return

# any connection with csi enabled is a valid connection
Expand All @@ -286,7 +304,8 @@ def get_auth_if_no_valid_key_vault_connection(logger, source_name, source_id, ke
return {'authType': 'userAssignedIdentity'}

else:
logger.warning('key vualt reference connection: %s', key_vault_connections[0].get('id'))
logger.warning('key vualt reference connection: %s',
key_vault_connections[0].get('id'))
return

auth_info = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,10 @@ def _infer_webapp(source_id):
return client_type

def _infer_springcloud(source_id):
client_type = None
client_type = CLIENT_TYPE.SpringBoot
try:
segments = parse_resource_id(source_id)
output = run_cli_cmd('az spring-cloud app show -g {} -s {} -n {}'
output = run_cli_cmd('az spring app show -g {} -s {} -n {}'
' -o json'.format(segments.get('resource_group'), segments.get('name'),
segments.get('child_name_1')))
prop_val = output.get('properties')\
Expand All @@ -180,7 +180,7 @@ def _infer_springcloud(source_id):
client_type = None
if 'webapp' in cmd.name:
client_type = _infer_webapp(namespace.source_id)
elif 'spring-cloud' in cmd.name:
elif 'spring-cloud' in cmd.name or 'spring' in cmd.name:
client_type = _infer_springcloud(namespace.source_id)

method = 'detected'
Expand Down Expand Up @@ -247,12 +247,12 @@ def interactive_input(arg, hint):
def get_local_context_value(cmd, arg):
'''Get local context values
'''
groups = ['all', 'cupertino', 'serviceconnector']
groups = ['all', 'cupertino', 'serviceconnector', 'postgres']
arg_map = {
'source_resource_group': ['resource_group_name'],
'target_resource_group': ['resource_group_name'],
'server': ['postgres_server_name'],
'database': ['postgres_database_name'],
'server': ['server_name', "server"],
'database': ['database_name', "database"],
'site': ['webapp_name']
}
for group in groups:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,11 @@ def get_action(self, values, option_string): # pylint: disable=no-self-use
raise ValidationError('Usage error: {} [KEY=VALUE ...]'.format(option_string))
d = {}
for k in properties:
raise ValidationError('Unsupported Key {} is provided for parameter --system-identity')
v = properties[k]
if k.lower() == 'mysql-identity-id':
d['mysql-identity-id'] = v[0]
else:
raise ValidationError('Unsupported Key {} is provided for parameter --system-identity')
d['auth_type'] = 'systemAssignedIdentity'
return d

Expand Down
Loading

0 comments on commit 19b0ebe

Please sign in to comment.