Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Service Connector] az spring-cloud connection create postgres: Add --system-identity for springcloud-postgres connection #22459

Merged
merged 57 commits into from
Sep 19, 2022
Merged
Show file tree
Hide file tree
Changes from 51 commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
8e68ad6
connect by mi for postgresql and spring cloud
xfz11 May 16, 2022
91e293e
connect by mi for postgresql and spring cloud
xfz11 May 16, 2022
a4469cc
Merge branch 'postgresmi' of github.com:xfz11/azure-cli into postgresmi
xfz11 May 16, 2022
7072cc4
update
xfz11 May 16, 2022
09e9b0b
update
xfz11 May 16, 2022
91ecdb4
update default auth type
xfz11 May 16, 2022
528d09a
update
xfz11 May 16, 2022
c8c5acc
set aad user
xfz11 May 17, 2022
31d124a
fix lint
xfz11 May 17, 2022
bd96280
Merge remote-tracking branch 'upstream/dev' into postgresmi
xfz11 May 17, 2022
855cb5a
update
xfz11 May 17, 2022
0e9959e
update error msg
xfz11 May 17, 2022
43729ba
Merge branch 'dev' into postgresmi
xfz11 Jul 22, 2022
eb0f1e0
support webapp
xfz11 Jul 22, 2022
c201b9e
update style
xfz11 Jul 22, 2022
037901b
support container app
xfz11 Aug 8, 2022
6db5cb4
support postgres flex server
xfz11 Aug 19, 2022
b8526cd
Merge branch 'postgresmi' of github.com:xfz11/azure-cli into postgresmi
xfz11 Aug 19, 2022
9d41073
Merge remote-tracking branch 'upstream/dev' into postgresmi
xfz11 Aug 19, 2022
086f03c
refactor module
xfz11 Aug 26, 2022
b46c906
refactor credential free
xfz11 Aug 29, 2022
fa384ff
Merge branch 'dev' into postgresmi
xfz11 Aug 29, 2022
14c191e
lint and add sql
xfz11 Aug 31, 2022
5b107ce
fix name and improve psql flex
xfz11 Aug 31, 2022
b568a24
support sql
xfz11 Aug 31, 2022
f909b63
lint
xfz11 Aug 31, 2022
f93f56b
remove pyodbc dependency
xfz11 Sep 2, 2022
07b6d48
update psql query
xfz11 Sep 5, 2022
95b5fee
support mysql
xfz11 Sep 5, 2022
3645afb
update sql suffix
xfz11 Sep 5, 2022
e09feb4
fix error message
xfz11 Sep 5, 2022
e98a7c8
update postgresql connection
xfz11 Sep 5, 2022
ce2133d
update
xfz11 Sep 5, 2022
69ed73e
fix
xfz11 Sep 5, 2022
1116e6e
add retry for get identity api
xfz11 Sep 6, 2022
9a6f8e4
update
xfz11 Sep 6, 2022
6c81f84
update connection error message
xfz11 Sep 7, 2022
fd6784a
Merge branch 'dev' into postgresmi
xfz11 Sep 7, 2022
69c0d33
support mysql identity
xfz11 Sep 7, 2022
96cc8c5
fix some bug
xfz11 Sep 8, 2022
eebb424
remove client type restriction
xfz11 Sep 9, 2022
254b4bf
update sql authentication
xfz11 Sep 9, 2022
fa0e8aa
update
xfz11 Sep 9, 2022
9d0aec1
lint
xfz11 Sep 9, 2022
7706291
update sql aad auth to azcli credential
xfz11 Sep 13, 2022
d7536e4
update psql query
xfz11 Sep 13, 2022
940b8e4
remove connection string log
xfz11 Sep 13, 2022
5bbfceb
update package
xfz11 Sep 14, 2022
018a367
rename mysql-identity-id
xfz11 Sep 14, 2022
4b86e51
revert test
xfz11 Sep 14, 2022
77888c7
lint
xfz11 Sep 14, 2022
d3181c7
solve comment
xfz11 Sep 15, 2022
c9696b9
Merge branch 'dev' into postgresmi
xfz11 Sep 15, 2022
89c56a8
update help
xfz11 Sep 15, 2022
a378cdd
fix errorkey
xfz11 Sep 16, 2022
3d8a8d5
pg_flex: enable uuid-ossp extension
xfz11 Sep 16, 2022
a6ed70d
fix mysql grant query
xfz11 Sep 19, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,13 @@ def add_new_addon_argument(context, source, target):
def add_secret_store_argument(context):
context.argument('key_vault_id', options_list=['--vault-id'], help='The id of key vault to store secret value')

def add_mysql_umi_argument(context, target):
if target in [RESOURCE.MysqlFlexible]:
context.argument('mysql_identity_id', options_list=['--mysql-identity-id'],
xfz11 marked this conversation as resolved.
Show resolved Hide resolved
help='The ID of identity used for MySQL flexible server AAD Authentication')
else:
c.ignore('mysql_identity_id')

def add_vnet_block(context, target):
if target not in TARGET_SUPPORT_SERVICE_ENDPOINT:
c.ignore('service_endpoint')
Expand Down Expand Up @@ -206,6 +213,7 @@ def add_confluent_kafka_argument(context):
add_secret_store_argument(c)
add_vnet_block(c, target)
add_connection_string_argument(c, source, target)
add_mysql_umi_argument(c, target)
with self.argument_context('{} connection update {}'.format(source.value, target.value)) as c:
add_client_type_argument(c, source, target)
add_connection_name_argument(c, source)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -651,11 +651,11 @@ class CLIENT_TYPE(Enum):
# The first one will be used as the default auth type
SUPPORTED_AUTH_TYPE = {
RESOURCE.WebApp: {
RESOURCE.Postgres: [AUTH_TYPE.Secret],
RESOURCE.PostgresFlexible: [AUTH_TYPE.Secret],
RESOURCE.Postgres: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity],
RESOURCE.PostgresFlexible: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity],
RESOURCE.Mysql: [AUTH_TYPE.Secret],
RESOURCE.MysqlFlexible: [AUTH_TYPE.Secret],
RESOURCE.Sql: [AUTH_TYPE.Secret],
RESOURCE.MysqlFlexible: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity],
RESOURCE.Sql: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity],
RESOURCE.Redis: [AUTH_TYPE.SecretAuto],
RESOURCE.RedisEnterprise: [AUTH_TYPE.SecretAuto],

Expand All @@ -679,11 +679,11 @@ class CLIENT_TYPE(Enum):
RESOURCE.ConfluentKafka: [AUTH_TYPE.Secret],
},
RESOURCE.SpringCloud: {
RESOURCE.Postgres: [AUTH_TYPE.Secret],
RESOURCE.PostgresFlexible: [AUTH_TYPE.Secret],
RESOURCE.Postgres: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity],
RESOURCE.PostgresFlexible: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity],
RESOURCE.Mysql: [AUTH_TYPE.Secret],
RESOURCE.MysqlFlexible: [AUTH_TYPE.Secret],
RESOURCE.Sql: [AUTH_TYPE.Secret],
RESOURCE.MysqlFlexible: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity],
RESOURCE.Sql: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity],
RESOURCE.Redis: [AUTH_TYPE.SecretAuto],
RESOURCE.RedisEnterprise: [AUTH_TYPE.SecretAuto],

Expand Down Expand Up @@ -735,11 +735,11 @@ class CLIENT_TYPE(Enum):
RESOURCE.ConfluentKafka: [AUTH_TYPE.Secret],
},
RESOURCE.ContainerApp: {
RESOURCE.Postgres: [AUTH_TYPE.Secret],
RESOURCE.PostgresFlexible: [AUTH_TYPE.Secret],
RESOURCE.Postgres: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity],
RESOURCE.PostgresFlexible: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity],
RESOURCE.Mysql: [AUTH_TYPE.Secret],
RESOURCE.MysqlFlexible: [AUTH_TYPE.Secret],
RESOURCE.Sql: [AUTH_TYPE.Secret],
RESOURCE.MysqlFlexible: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity],
RESOURCE.Sql: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity],
RESOURCE.Redis: [AUTH_TYPE.SecretAuto],
RESOURCE.RedisEnterprise: [AUTH_TYPE.SecretAuto],

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,23 @@
# --------------------------------------------------------------------------------------------

import time
from knack.log import get_logger
from knack.util import todict
from msrestazure.tools import parse_resource_id
from azure.cli.core.azclierror import (
ValidationError,
CLIInternalError
)
from azure.cli.core._profile import Profile
from ._resource_config import (
SOURCE_RESOURCES_USERTOKEN,
TARGET_RESOURCES_USERTOKEN
TARGET_RESOURCES_USERTOKEN,
RESOURCE
)
# pylint: disable=unused-argument, not-an-iterable, too-many-statements
xfz11 marked this conversation as resolved.
Show resolved Hide resolved


logger = get_logger(__name__)


def should_load_source(source):
Expand Down Expand Up @@ -44,7 +51,8 @@ def generate_random_string(length=5, prefix='', lower_only=False, ensure_complex
import string

if lower_only and ensure_complexity:
raise CLIInternalError('lower_only and ensure_complexity can not both be specified to True')
raise CLIInternalError(
'lower_only and ensure_complexity can not both be specified to True')
if ensure_complexity and length < 8:
raise CLIInternalError('ensure_complexity needs length >= 8')

Expand All @@ -53,46 +61,52 @@ def generate_random_string(length=5, prefix='', lower_only=False, ensure_complex
character_set = string.ascii_lowercase

while True:
randstr = '{}{}'.format(prefix, ''.join(random.sample(character_set, length)))
randstr = '{}{}'.format(prefix, ''.join(
random.sample(character_set, length)))
lowers = [c for c in randstr if c.islower()]
uppers = [c for c in randstr if c.isupper()]
numbers = [c for c in randstr if c.isnumeric()]
if not ensure_complexity or (lowers and uppers and numbers):
break

return randstr
return randstr.lower()
xfz11 marked this conversation as resolved.
Show resolved Hide resolved


def run_cli_cmd(cmd, retry=0):
def run_cli_cmd(cmd, retry=0, interval=0, should_retry_func=None):
'''Run a CLI command
:param cmd: The CLI command to be executed
:param retry: The times to re-try
:param interval: The seconds wait before retry
'''
import json
import subprocess

output = subprocess.run(cmd, shell=True, check=False, stderr=subprocess.PIPE, stdout=subprocess.PIPE)
if output.returncode != 0:
output = subprocess.run(cmd, shell=True, check=False,
stderr=subprocess.PIPE, stdout=subprocess.PIPE)
logger.debug(output)
if output.returncode != 0 or (should_retry_func and should_retry_func(output)):
if retry:
run_cli_cmd(cmd, retry - 1)
else:
raise CLIInternalError('Command execution failed, command is: '
'{}, error message is: {}'.format(cmd, output.stderr))

return json.loads(output.stdout) if output.stdout else None
time.sleep(interval)
return run_cli_cmd(cmd, retry - 1, interval)
raise CLIInternalError('Command execution failed, command is: '
'{}, error message is: {}'.format(cmd, output.stderr))
try:
return json.loads(output.stdout) if output.stdout else None
except ValueError:
return output.stdout or None


def set_user_token_header(client, cli_ctx):
'''Set user token header to work around OBO'''
from azure.cli.core._profile import Profile

# pylint: disable=protected-access
# HACK: set custom header to work around OBO
profile = Profile(cli_ctx=cli_ctx)
creds, _, _ = profile.get_raw_token()
client._client._config.headers_policy._headers['x-ms-serviceconnector-user-token'] = creds[1]
# HACK: hide token header
client._config.logging_policy.headers_to_redact.append('x-ms-serviceconnector-user-token')
client._config.logging_policy.headers_to_redact.append(
'x-ms-serviceconnector-user-token')

return client

Expand All @@ -109,16 +123,14 @@ def provider_is_registered(subscription=None):
subs_arg = ''
if subscription:
subs_arg = '--subscription {}'.format(subscription)
output = run_cli_cmd('az provider show -n Microsoft.ServiceLinker {}'.format(subs_arg))
output = run_cli_cmd(
'az provider show -n Microsoft.ServiceLinker {}'.format(subs_arg))
if output.get('registrationState') == 'NotRegistered':
return False
return True


def register_provider(subscription=None):
from knack.log import get_logger
logger = get_logger(__name__)

logger.warning('Provider Microsoft.ServiceLinker is not registered, '
'trying to register. This usually takes 1-2 minutes.')

Expand All @@ -127,7 +139,8 @@ def register_provider(subscription=None):
subs_arg = '--subscription {}'.format(subscription)

# register the provider
run_cli_cmd('az provider register -n Microsoft.ServiceLinker {}'.format(subs_arg))
run_cli_cmd(
'az provider register -n Microsoft.ServiceLinker {}'.format(subs_arg))

# verify the registration, 30 * 10s polling the result
MAX_RETRY_TIMES = 30
Expand All @@ -136,7 +149,8 @@ def register_provider(subscription=None):
count = 0
while count < MAX_RETRY_TIMES:
time.sleep(RETRY_INTERVAL)
output = run_cli_cmd('az provider show -n Microsoft.ServiceLinker {}'.format(subs_arg))
output = run_cli_cmd(
'az provider show -n Microsoft.ServiceLinker {}'.format(subs_arg))
current_state = output.get('registrationState')
if current_state == 'Registered':
return True
Expand Down Expand Up @@ -171,7 +185,8 @@ def auto_register(func, *args, **kwargs):
# target subscription is not registered, raw check
if ex.error and ex.error.code == 'UnauthorizedResourceAccess' and 'not registered' in ex.error.message:
if 'parameters' in kwargs_backup and 'target_id' in kwargs_backup.get('parameters'):
segments = parse_resource_id(kwargs_backup.get('parameters').get('target_id'))
segments = parse_resource_id(
kwargs_backup.get('parameters').get('target_id'))
target_subs = segments.get('subscription')
# double check whether target subscription is registered
if not provider_is_registered(target_subs):
Expand All @@ -185,8 +200,6 @@ def auto_register(func, *args, **kwargs):

def create_key_vault_reference_connection_if_not_exist(cmd, client, source_id, key_vault_id):
from ._validators import get_source_resource_name
from knack.log import get_logger
logger = get_logger(__name__)

logger.warning('get valid key vualt reference connection')
key_vault_connections = []
Expand All @@ -196,15 +209,15 @@ def create_key_vault_reference_connection_if_not_exist(cmd, client, source_id, k
key_vault_connections.append(connection)

source_name = get_source_resource_name(cmd)
auth_info = get_auth_if_no_valid_key_vault_connection(logger, source_name, source_id, key_vault_connections)
auth_info = get_auth_if_no_valid_key_vault_connection(
source_name, source_id, key_vault_connections)
if not auth_info:
return

# No Valid Key Vault Connection, Create
logger.warning('no valid key vault connection found. Creating...')

from ._resource_config import (
RESOURCE,
CLIENT_TYPE
)

Expand All @@ -215,7 +228,8 @@ def create_key_vault_reference_connection_if_not_exist(cmd, client, source_id, k
"id": key_vault_id
},
'auth_info': auth_info,
'client_type': CLIENT_TYPE.Dotnet, # Key Vault Configuration are same across all client types
# Key Vault Configuration are same across all client types
'client_type': CLIENT_TYPE.Dotnet,
}

if source_name == RESOURCE.KubernetesCluster:
Expand All @@ -230,22 +244,23 @@ def create_key_vault_reference_connection_if_not_exist(cmd, client, source_id, k
parameters=parameters)


def get_auth_if_no_valid_key_vault_connection(logger, source_name, source_id, key_vault_connections):
def get_auth_if_no_valid_key_vault_connection(source_name, source_id, key_vault_connections):
auth_type = 'systemAssignedIdentity'
client_id = None
subscription_id = None

if key_vault_connections:
from ._resource_config import RESOURCE
from msrestazure.tools import (
is_valid_resource_id
)

# https://docs.microsoft.com/azure/app-service/app-service-key-vault-references
if source_name == RESOURCE.WebApp:
try:
webapp = run_cli_cmd('az rest -u {}?api-version=2020-09-01 -o json'.format(source_id))
reference_identity = webapp.get('properties').get('keyVaultReferenceIdentity')
webapp = run_cli_cmd(
'az rest -u {}?api-version=2020-09-01 -o json'.format(source_id))
reference_identity = webapp.get(
'properties').get('keyVaultReferenceIdentity')
except Exception as e:
raise ValidationError('{}. Unable to get "properties.keyVaultReferenceIdentity" from {}.'
'Please check your source id is correct.'.format(e, source_id))
Expand All @@ -255,11 +270,13 @@ def get_auth_if_no_valid_key_vault_connection(logger, source_name, source_id, ke
segments = parse_resource_id(reference_identity)
subscription_id = segments.get('subscription')
try:
identity = webapp.get('identity').get('userAssignedIdentities').get(reference_identity)
identity = webapp.get('identity').get(
'userAssignedIdentities').get(reference_identity)
client_id = identity.get('clientId')
except Exception: # pylint: disable=broad-except
try:
identity = run_cli_cmd('az identity show --ids {} -o json'.format(reference_identity))
identity = run_cli_cmd(
'az identity show --ids {} -o json'.format(reference_identity))
client_id = identity.get('clientId')
except Exception: # pylint: disable=broad-except
pass
Expand All @@ -269,12 +286,14 @@ def get_auth_if_no_valid_key_vault_connection(logger, source_name, source_id, ke
for connection in key_vault_connections:
auth_info = connection.get('authInfo')
if auth_info.get('clientId') == client_id and auth_info.get('subscriptionId') == subscription_id:
logger.warning('key vualt reference connection: %s', connection.get('id'))
logger.warning(
'key vualt reference connection: %s', connection.get('id'))
return
else: # System Identity
for connection in key_vault_connections:
if connection.get('authInfo').get('authType') == auth_type:
logger.warning('key vualt reference connection: %s', connection.get('id'))
logger.warning(
'key vualt reference connection: %s', connection.get('id'))
return

# any connection with csi enabled is a valid connection
Expand All @@ -286,7 +305,8 @@ def get_auth_if_no_valid_key_vault_connection(logger, source_name, source_id, ke
return {'authType': 'userAssignedIdentity'}

else:
logger.warning('key vualt reference connection: %s', key_vault_connections[0].get('id'))
logger.warning('key vualt reference connection: %s',
key_vault_connections[0].get('id'))
return

auth_info = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,10 @@ def _infer_webapp(source_id):
return client_type

def _infer_springcloud(source_id):
client_type = None
client_type = CLIENT_TYPE.SpringBoot
try:
segments = parse_resource_id(source_id)
output = run_cli_cmd('az spring-cloud app show -g {} -s {} -n {}'
output = run_cli_cmd('az spring app show -g {} -s {} -n {}'
' -o json'.format(segments.get('resource_group'), segments.get('name'),
segments.get('child_name_1')))
prop_val = output.get('properties')\
Expand All @@ -180,7 +180,7 @@ def _infer_springcloud(source_id):
client_type = None
if 'webapp' in cmd.name:
client_type = _infer_webapp(namespace.source_id)
elif 'spring-cloud' in cmd.name:
elif 'spring-cloud' in cmd.name or 'spring' in cmd.name:
client_type = _infer_springcloud(namespace.source_id)

method = 'detected'
Expand Down Expand Up @@ -247,12 +247,12 @@ def interactive_input(arg, hint):
def get_local_context_value(cmd, arg):
'''Get local context values
'''
groups = ['all', 'cupertino', 'serviceconnector']
groups = ['all', 'cupertino', 'serviceconnector', 'postgres']
arg_map = {
'source_resource_group': ['resource_group_name'],
'target_resource_group': ['resource_group_name'],
'server': ['postgres_server_name'],
'database': ['postgres_database_name'],
'server': ['server_name', "server"],
'database': ['database_name', "database"],
'site': ['webapp_name']
}
for group in groups:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
set_user_token_header,
auto_register
)
# pylint: disable=unused-argument,unsubscriptable-object,unsupported-membership-test
from ._credential_free import enable_mi_for_db_linker
# pylint: disable=unused-argument,unsubscriptable-object,unsupported-membership-test,too-many-statements,too-many-locals


logger = get_logger(__name__)
Expand Down Expand Up @@ -159,6 +160,7 @@ def connection_create(cmd, client, # pylint: disable=too-many-locals disable=to
service_endpoint=None,
private_endpoint=None,
store_in_connection_string=False,
mysql_identity_id=None,
new_addon=False, no_wait=False,
cluster=None, scope=None, enable_csi=False, # Resource.KubernetesCluster
site=None, # Resource.WebApp
Expand Down Expand Up @@ -263,6 +265,9 @@ def connection_create(cmd, client, # pylint: disable=too-many-locals disable=to
'manually and then create the connection.'.format(str(e)))

validate_service_state(parameters)
new_auth_info = enable_mi_for_db_linker(cmd, source_id, target_id, auth_info, source_type, target_type, client_type,
xfz11 marked this conversation as resolved.
Show resolved Hide resolved
connection_name, mysql_identity_id)
parameters['auth_info'] = new_auth_info if new_auth_info is not None else parameters['auth_info']
return auto_register(sdk_no_wait, no_wait,
client.begin_create_or_update,
resource_uri=source_id,
Expand Down
Loading