From 19b0ebea8ae9a4bbc4ebaf413518daec807fccb4 Mon Sep 17 00:00:00 2001 From: xfz11 <81600993+xfz11@users.noreply.github.com> Date: Mon, 19 Sep 2022 11:53:30 +0800 Subject: [PATCH] [Service Connector] `az spring-cloud connection create postgres`: Add `--system-identity` for springcloud-postgres connection (#22459) * connect by mi for postgresql and spring cloud * connect by mi for postgresql and spring cloud * update * update * update default auth type * update * set aad user * fix lint * update * update error msg refine output message update enable localcontext update setup update client type and output message add dependencies alphabetically update setup update warning add auto commit update username to lower case add new permissions update command switch to az spring add dependencies Revert "add dependencies" This reverts commit 6ddbf9483a52e3c4169246033d0b8950bbf565e9. update psycopg2 install dependency runtime update to spring command remove linux pg dependency * support webapp * update style * support container app * support postgres flex server * refactor module * refactor credential free * lint and add sql * fix name and improve psql flex * support sql * lint * remove pyodbc dependency * update psql query * support mysql * update sql suffix * fix error message * update postgresql connection * update * fix * add retry for get identity api * update * update connection error message * support mysql identity * fix some bug * remove client type restriction * update sql authentication * update * lint * update sql aad auth to azcli credential * update psql query * remove connection string log * update package * rename mysql-identity-id * revert test * lint * solve comment * update help * fix errorkey * pg_flex: enable uuid-ossp extension * fix mysql grant query --- .../serviceconnector/_credential_free.py | 761 ++++++++++++++++++ .../command_modules/serviceconnector/_help.py | 18 +- .../serviceconnector/_resource_config.py | 24 +- .../serviceconnector/_utils.py | 89 +- .../serviceconnector/_validators.py | 12 +- .../serviceconnector/action.py | 6 +- .../serviceconnector/custom.py | 5 +- .../test_passwordless_connection_scenario.py | 167 ++++ src/azure-cli/requirements.py3.Darwin.txt | 1 + src/azure-cli/requirements.py3.Linux.txt | 1 + src/azure-cli/requirements.py3.windows.txt | 2 + src/azure-cli/setup.py | 1 + 12 files changed, 1030 insertions(+), 57 deletions(-) create mode 100644 src/azure-cli/azure/cli/command_modules/serviceconnector/_credential_free.py create mode 100644 src/azure-cli/azure/cli/command_modules/serviceconnector/tests/latest/test_passwordless_connection_scenario.py diff --git a/src/azure-cli/azure/cli/command_modules/serviceconnector/_credential_free.py b/src/azure-cli/azure/cli/command_modules/serviceconnector/_credential_free.py new file mode 100644 index 00000000000..86aa644ac66 --- /dev/null +++ b/src/azure-cli/azure/cli/command_modules/serviceconnector/_credential_free.py @@ -0,0 +1,761 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import struct +from knack.log import get_logger +from knack.util import CLIError +from msrestazure.tools import parse_resource_id +from azure.cli.core.azclierror import ( + AzureConnectionError, + ValidationError +) +from azure.cli.core.extension.operations import _install_deps_for_psycopg2 +from azure.cli.core.profiles import ResourceType +from azure.cli.core._profile import Profile +from azure.cli.core.commands.client_factory import get_mgmt_service_client +from azure.cli.core.util import random_string +from azure.cli.core.commands import LongRunningOperation +from azure.cli.core.commands.arm import ArmTemplateBuilder +from ._utils import run_cli_cmd, generate_random_string +from ._resource_config import ( + RESOURCE, +) +from ._validators import ( + get_source_resource_name, + get_target_resource_name, +) + +logger = get_logger(__name__) + + +# pylint: disable=line-too-long +# For db(mysqlFlex/psql/psqlFlex/sql) linker with auth type=systemAssignedIdentity, enable AAD auth and create db user on data plane +# For other linker, ignore the steps +def enable_mi_for_db_linker(cmd, source_id, target_id, auth_info, client_type, connection_name): + # return if connection is not for db mi + if auth_info['auth_type'] not in {'systemAssignedIdentity'}: + return + + source_type = get_source_resource_name(cmd) + target_type = get_target_resource_name(cmd) + source_handler = getSourceHandler(source_id, source_type) + if source_handler is None: + return + target_handler = getTargetHandler( + cmd, target_id, target_type, auth_info['auth_type'], client_type, connection_name) + if target_handler is None: + return + + user_info = run_cli_cmd( + 'az ad user show --id {}'.format(target_handler.login_username)) + user_object_id = user_info.get('objectId') if user_info.get('objectId') is not None \ + else user_info.get('id') + if user_object_id is None: + raise Exception( + "No object id found for user {}".format(target_handler.login_username)) + + # enable source mi + source_object_id = source_handler.get_identity_pid() + + identity_info = run_cli_cmd( + 'az ad sp show --id {}'.format(source_object_id), 15, 10) + client_id = identity_info.get('appId') + identity_name = identity_info.get('displayName') + + # enable target aad authentication and set login user as db aad admin + target_handler.enable_target_aad_auth() + target_handler.set_user_admin( + user_object_id, mysql_identity_id=auth_info.get('mysql-identity-id')) + + # create an aad user in db + target_handler.create_aad_user(identity_name, client_id) + return target_handler.get_auth_config() + + +# pylint: disable=no-self-use, unused-argument, too-many-instance-attributes +def getTargetHandler(cmd, target_id, target_type, auth_type, client_type, connection_name): + if target_type in {RESOURCE.Sql}: + return SqlHandler(cmd, target_id, target_type, auth_type, connection_name) + if target_type in {RESOURCE.Postgres}: + return PostgresSingleHandler(cmd, target_id, target_type, auth_type, connection_name) + if target_type in {RESOURCE.PostgresFlexible}: + return PostgresFlexHandler(cmd, target_id, target_type, auth_type, connection_name) + if target_type in {RESOURCE.MysqlFlexible}: + return MysqlFlexibleHandler(cmd, target_id, target_type, auth_type, connection_name) + return None + + +class TargetHandler: + target_id = "" + target_type = "" + profile = None + cmd = None + tenant_id = "" + subscription = "" + resource_group = "" + login_username = "" + endpoint = "" + aad_username = "" + + auth_type = "" + + def __init__(self, cmd, target_id, target_type, auth_type, connection_name): + self.profile = Profile(cli_ctx=cmd.cli_ctx) + self.cmd = cmd + self.target_id = target_id + self.target_type = target_type + self.aad_username = "aad_" + connection_name + self.tenant_id = Profile( + cli_ctx=cmd.cli_ctx).get_subscription().get("tenantId") + target_segments = parse_resource_id(target_id) + self.subscription = target_segments.get('subscription') + self.resource_group = target_segments.get('resource_group') + self.auth_type = auth_type + self.login_username = run_cli_cmd( + 'az account show').get("user").get("name") + + def enable_target_aad_auth(self): + return + + def set_user_admin(self, user_object_id, **kwargs): + return + + def set_target_firewall(self, add_new_rule, ip_name): + return + + def create_aad_user(self, identity_name, client_id): + return + + def get_auth_config(self): + return + + +class MysqlFlexibleHandler(TargetHandler): + + server = "" + dbname = "" + + def __init__(self, cmd, target_id, target_type, auth_type, connection_name): + super().__init__(cmd, target_id, target_type, auth_type, connection_name) + self.endpoint = cmd.cli_ctx.cloud.suffixes.mysql_server_endpoint + target_segments = parse_resource_id(target_id) + self.server = target_segments.get('name') + self.dbname = target_segments.get('child_name_1') + + def set_user_admin(self, user_object_id, **kwargs): + mysql_identity_id = kwargs['mysql_identity_id'] + admins = run_cli_cmd( + 'az mysql flexible-server ad-admin list -g {} -s {} --subscription {}'.format( + self.resource_group, self.server, self.subscription) + ) + is_admin = any(ad.get('sid') == user_object_id for ad in admins) + if is_admin: + return + + logger.warning('Set current user as DB Server AAD Administrators.') + # set user as AAD admin + if mysql_identity_id is None: + raise ValidationError( + "Provide '--system-identity mysql-identity-id=xx' to set {} as AAD administrator.".format(self.user)) + mysql_umi = run_cli_cmd( + 'az mysql flexible-server identity list -g {} -s {} --subscription {}'.format(self.resource_group, self.server, self.subscription)) + if (not mysql_umi) or mysql_identity_id not in mysql_umi.get("userAssignedIdentities"): + run_cli_cmd('az mysql flexible-server identity assign -g {} -s {} --subscription {} --identity {}'.format( + self.resource_group, self.server, self.subscription, mysql_identity_id)) + run_cli_cmd('az mysql flexible-server ad-admin create -g {} -s {} --subscription {} -u {} -i {} --identity {}'.format( + self.resource_group, self.server, self.subscription, self.login_username, user_object_id, mysql_identity_id)) + + def create_aad_user(self, identity_name, client_id): + query_list = self.get_create_query(client_id) + connection_kwargs = self.get_connection_string() + ip_name = None + try: + logger.warning("Connecting to database...") + self.create_aad_user_in_mysql(connection_kwargs, query_list) + except AzureConnectionError: + # allow public access + ip_name = generate_random_string(prefix='svc_').lower() + self.set_target_firewall(True, ip_name) + # create again + self.create_aad_user_in_mysql(connection_kwargs, query_list) + + # remove firewall rule + if ip_name is not None: + try: + self.set_target_firewall(False, ip_name) + # pylint: disable=bare-except + except: + pass + # logger.warning('Please manually delete firewall rule %s to avoid security issue', ipname) + + def set_target_firewall(self, add_new_rule, ip_name): + if add_new_rule: + target = run_cli_cmd( + 'az mysql flexible-server show --ids {}'.format(self.target_id)) + # logger.warning("Update database server firewall rule to connect...") + if target.get('network').get('publicNetworkAccess') == "Disabled": + return + run_cli_cmd( + 'az mysql flexible-server firewall-rule create --resource-group {0} --name {1} --rule-name {2} ' + '--subscription {3} --start-ip-address 0.0.0.0 --end-ip-address 255.255.255.255'.format( + self.resource_group, self.server, ip_name, self.subscription) + ) + # logger.warning("Remove database server firewall rules to recover...") + # run_cli_cmd('az mysql server firewall-rule delete -g {0} -s {1} -n {2} -y'.format(rg, server, ipname)) + # if deny_public_access: + # run_cli_cmd('az mysql server update --public Disabled --ids {}'.format(target_id)) + + def create_aad_user_in_mysql(self, connection_kwargs, query_list): + import pkg_resources + installed_packages = pkg_resources.working_set + # pylint: disable=not-an-iterable + pym_installed = any(('pymysql') in d.key.lower() + for d in installed_packages) + if not pym_installed: + import pip + pip.main(['install', 'mycli']) + # pylint: disable=import-error + import pymysql + from pymysql.constants import CLIENT + + connection_kwargs['client_flag'] = CLIENT.MULTI_STATEMENTS + try: + connection = pymysql.connect(**connection_kwargs) + cursor = connection.cursor() + for q in query_list: + try: + logger.debug(q) + cursor.execute(q) + except Exception as e: # pylint: disable=broad-except + logger.warning( + "Query %s, error: %s", q, str(e)) + except pymysql.Error as e: + raise AzureConnectionError("Fail to connect mysql. " + str(e)) + if cursor is not None: + try: + cursor.close() + except Exception as e: # pylint: disable=broad-except + raise CLIError(str(e)) + + def get_connection_string(self): + password = run_cli_cmd( + 'az account get-access-token --resource-type oss-rdbms').get('accessToken') + + return { + 'host': self.server + self.endpoint, + 'database': self.dbname, + 'user': self.login_username, + 'password': password, + 'ssl': {"fake_flag_to_enable_tls": True}, + 'autocommit': True + } + + def get_create_query(self, client_id): + return [ + "SET aad_auth_validate_oids_in_tenant = OFF;", + "DROP USER IF EXISTS '{}'@'%';".format(self.aad_username), + "CREATE AADUSER '{}' IDENTIFIED BY '{}';".format( + self.aad_username, client_id), + "GRANT ALL PRIVILEGES ON {}.* TO '{}'@'%';".format( + self.dbname, self.aad_username), + "FLUSH privileges;" + ] + + def get_auth_config(self): + if self.auth_type in {'systemAssignedIdentity'}: + return { + 'auth_type': 'secret', + 'name': self.aad_username, + 'secret_info': { + 'secret_type': 'rawValue' + } + } + + +class SqlHandler(TargetHandler): + + server = "" + dbname = "" + + def __init__(self, cmd, target_id, target_type, auth_type, connection_name): + super().__init__(cmd, target_id, target_type, auth_type, connection_name) + self.endpoint = cmd.cli_ctx.cloud.suffixes.sql_server_hostname + target_segments = parse_resource_id(target_id) + self.server = target_segments.get('name') + self.dbname = target_segments.get('child_name_1') + + def set_user_admin(self, user_object_id, **kwargs): + # pylint: disable=not-an-iterable + admins = run_cli_cmd( + 'az sql server ad-admin list --ids {}'.format(self.target_id)) + is_admin = any(ad.get('sid') == user_object_id for ad in admins) + if not is_admin: + logger.warning('Setting current user as database server AAD admin:' + ' user=%s object id=%s', self.login_username, user_object_id) + run_cli_cmd('az sql server ad-admin create -g {} --server-name {} --display-name {} --object-id {} --subscription {}'.format( + self.resource_group, self.server, self.login_username, user_object_id, self.subscription)).get('objectId') + + def create_aad_user(self, identity_name, client_id): + self.aad_username = identity_name + + query_list = self.get_create_query(client_id) + connection_args = self.get_connection_string() + ip_name = None + try: + logger.warning("Connecting to database...") + self.create_aad_user_in_sql(connection_args, query_list) + except AzureConnectionError: + # allow public access + ip_name = generate_random_string(prefix='svc_').lower() + self.set_target_firewall(True, ip_name) + # create again + self.create_aad_user_in_sql(connection_args, query_list) + + # remove firewall rule + if ip_name is not None: + try: + self.set_target_firewall(False, ip_name) + # pylint: disable=bare-except + except: + pass + # logger.warning('Please manually delete firewall rule %s to avoid security issue', ipname) + + def set_target_firewall(self, add_new_rule, ip_name): + if add_new_rule: + target = run_cli_cmd( + 'az sql server show --ids {}'.format(self.target_id)) + # logger.warning("Update database server firewall rule to connect...") + if target.get('publicNetworkAccess') == "Disabled": + run_cli_cmd( + 'az sql server update -e true --ids {}'.format(self.target_id)) + run_cli_cmd( + 'az sql server firewall-rule create -g {0} -s {1} -n {2} ' + '--subscription {3} --start-ip-address 0.0.0.0 --end-ip-address 255.255.255.255'.format( + self.resource_group, self.server, ip_name, self.subscription) + ) + return False + + def create_aad_user_in_sql(self, connection_args, query_list): + import pkg_resources + installed_packages = pkg_resources.working_set + # pylint: disable=not-an-iterable + psy_installed = any(('pyodbc') in d.key.lower() + for d in installed_packages) + + if not psy_installed: + import pip + pip.main(['install', 'pyodbc']) + logger.warning( + "Please manually install odbc 18 for SQL server, reference: https://docs.microsoft.com/en-us/sql/connect/odbc/download-odbc-driver-for-sql-server?view=sql-server-ver16 " + "and run 'pip install pyodbc'") + # pylint: disable=import-error, c-extension-no-member + import pyodbc + try: + with pyodbc.connect(connection_args.get("connection_string"), attrs_before=connection_args.get("attrs_before")) as conn: + with conn.cursor() as cursor: + for execution_query in query_list: + try: + cursor.execute(execution_query) + except pyodbc.ProgrammingError as e: + logger.warning(e) + conn.commit() + except pyodbc.Error as e: + raise AzureConnectionError("Fail to connect sql. " + str(e)) + + def get_connection_string(self): + token_bytes = run_cli_cmd( + 'az account get-access-token --output json --resource https://database.windows.net/').get('accessToken').encode('utf-16-le') + + token_struct = struct.pack(f' is required. Password or account key for secret auth. @@ -235,13 +237,25 @@ def get_source_display_name(sourcename): Usage: --secret ''' if AUTH_TYPE.SecretAuto in auth_types else '' - system_identity_param = ''' + system_identity_param = '' + if AUTH_TYPE.SystemIdentity in auth_types: + if target in {RESOURCE.MysqlFlexible}: + system_identity_param = ''' + - name: --system-identity + short-summary: The system assigned identity auth info + long-summary: | + Usage: --system-identity mysql-identity-id=xx + + mysql-identity-id : Optional. ID of identity used for MySQL flexible server AAD Authentication. Ignore it if you are the server AAD administrator. + ''' + else: + system_identity_param = ''' - name: --system-identity short-summary: The system assigned identity auth info long-summary: | Usage: --system-identity - ''' if AUTH_TYPE.SystemIdentity in auth_types else '' + ''' user_identity_param = ''' - name: --user-identity short-summary: The user assigned identity auth info diff --git a/src/azure-cli/azure/cli/command_modules/serviceconnector/_resource_config.py b/src/azure-cli/azure/cli/command_modules/serviceconnector/_resource_config.py index 1371657498b..5182df0d740 100644 --- a/src/azure-cli/azure/cli/command_modules/serviceconnector/_resource_config.py +++ b/src/azure-cli/azure/cli/command_modules/serviceconnector/_resource_config.py @@ -651,11 +651,11 @@ class CLIENT_TYPE(Enum): # The first one will be used as the default auth type SUPPORTED_AUTH_TYPE = { RESOURCE.WebApp: { - RESOURCE.Postgres: [AUTH_TYPE.Secret], - RESOURCE.PostgresFlexible: [AUTH_TYPE.Secret], + RESOURCE.Postgres: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity], + RESOURCE.PostgresFlexible: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity], RESOURCE.Mysql: [AUTH_TYPE.Secret], - RESOURCE.MysqlFlexible: [AUTH_TYPE.Secret], - RESOURCE.Sql: [AUTH_TYPE.Secret], + RESOURCE.MysqlFlexible: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity], + RESOURCE.Sql: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity], RESOURCE.Redis: [AUTH_TYPE.SecretAuto], RESOURCE.RedisEnterprise: [AUTH_TYPE.SecretAuto], @@ -679,11 +679,11 @@ class CLIENT_TYPE(Enum): RESOURCE.ConfluentKafka: [AUTH_TYPE.Secret], }, RESOURCE.SpringCloud: { - RESOURCE.Postgres: [AUTH_TYPE.Secret], - RESOURCE.PostgresFlexible: [AUTH_TYPE.Secret], + RESOURCE.Postgres: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity], + RESOURCE.PostgresFlexible: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity], RESOURCE.Mysql: [AUTH_TYPE.Secret], - RESOURCE.MysqlFlexible: [AUTH_TYPE.Secret], - RESOURCE.Sql: [AUTH_TYPE.Secret], + RESOURCE.MysqlFlexible: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity], + RESOURCE.Sql: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity], RESOURCE.Redis: [AUTH_TYPE.SecretAuto], RESOURCE.RedisEnterprise: [AUTH_TYPE.SecretAuto], @@ -735,11 +735,11 @@ class CLIENT_TYPE(Enum): RESOURCE.ConfluentKafka: [AUTH_TYPE.Secret], }, RESOURCE.ContainerApp: { - RESOURCE.Postgres: [AUTH_TYPE.Secret], - RESOURCE.PostgresFlexible: [AUTH_TYPE.Secret], + RESOURCE.Postgres: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity], + RESOURCE.PostgresFlexible: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity], RESOURCE.Mysql: [AUTH_TYPE.Secret], - RESOURCE.MysqlFlexible: [AUTH_TYPE.Secret], - RESOURCE.Sql: [AUTH_TYPE.Secret], + RESOURCE.MysqlFlexible: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity], + RESOURCE.Sql: [AUTH_TYPE.Secret, AUTH_TYPE.SystemIdentity], RESOURCE.Redis: [AUTH_TYPE.SecretAuto], RESOURCE.RedisEnterprise: [AUTH_TYPE.SecretAuto], diff --git a/src/azure-cli/azure/cli/command_modules/serviceconnector/_utils.py b/src/azure-cli/azure/cli/command_modules/serviceconnector/_utils.py index d7ce6c6cdd4..d5a39664faa 100644 --- a/src/azure-cli/azure/cli/command_modules/serviceconnector/_utils.py +++ b/src/azure-cli/azure/cli/command_modules/serviceconnector/_utils.py @@ -4,18 +4,24 @@ # -------------------------------------------------------------------------------------------- import time +from knack.log import get_logger from knack.util import todict from msrestazure.tools import parse_resource_id from azure.cli.core.azclierror import ( ValidationError, CLIInternalError ) +from azure.cli.core._profile import Profile from ._resource_config import ( SOURCE_RESOURCES_USERTOKEN, - TARGET_RESOURCES_USERTOKEN + TARGET_RESOURCES_USERTOKEN, + RESOURCE ) +logger = get_logger(__name__) + + def should_load_source(source): '''Check whether to load `az {source} connection` If {source} is an extension (e.g, spring-cloud), load the command group only when {source} is installed @@ -44,7 +50,8 @@ def generate_random_string(length=5, prefix='', lower_only=False, ensure_complex import string if lower_only and ensure_complexity: - raise CLIInternalError('lower_only and ensure_complexity can not both be specified to True') + raise CLIInternalError( + 'lower_only and ensure_complexity can not both be specified to True') if ensure_complexity and length < 8: raise CLIInternalError('ensure_complexity needs length >= 8') @@ -53,7 +60,8 @@ def generate_random_string(length=5, prefix='', lower_only=False, ensure_complex character_set = string.ascii_lowercase while True: - randstr = '{}{}'.format(prefix, ''.join(random.sample(character_set, length))) + randstr = '{}{}'.format(prefix, ''.join( + random.sample(character_set, length))) lowers = [c for c in randstr if c.islower()] uppers = [c for c in randstr if c.isupper()] numbers = [c for c in randstr if c.isnumeric()] @@ -63,28 +71,32 @@ def generate_random_string(length=5, prefix='', lower_only=False, ensure_complex return randstr -def run_cli_cmd(cmd, retry=0): +def run_cli_cmd(cmd, retry=0, interval=0, should_retry_func=None): '''Run a CLI command :param cmd: The CLI command to be executed :param retry: The times to re-try + :param interval: The seconds wait before retry ''' import json import subprocess - output = subprocess.run(cmd, shell=True, check=False, stderr=subprocess.PIPE, stdout=subprocess.PIPE) - if output.returncode != 0: + output = subprocess.run(cmd, shell=True, check=False, + stderr=subprocess.PIPE, stdout=subprocess.PIPE) + logger.debug(output) + if output.returncode != 0 or (should_retry_func and should_retry_func(output)): if retry: - run_cli_cmd(cmd, retry - 1) - else: - raise CLIInternalError('Command execution failed, command is: ' - '{}, error message is: {}'.format(cmd, output.stderr)) - - return json.loads(output.stdout) if output.stdout else None + time.sleep(interval) + return run_cli_cmd(cmd, retry - 1, interval) + raise CLIInternalError('Command execution failed, command is: ' + '{}, error message is: {}'.format(cmd, output.stderr)) + try: + return json.loads(output.stdout) if output.stdout else None + except ValueError: + return output.stdout or None def set_user_token_header(client, cli_ctx): '''Set user token header to work around OBO''' - from azure.cli.core._profile import Profile # pylint: disable=protected-access # HACK: set custom header to work around OBO @@ -92,7 +104,8 @@ def set_user_token_header(client, cli_ctx): creds, _, _ = profile.get_raw_token() client._client._config.headers_policy._headers['x-ms-serviceconnector-user-token'] = creds[1] # HACK: hide token header - client._config.logging_policy.headers_to_redact.append('x-ms-serviceconnector-user-token') + client._config.logging_policy.headers_to_redact.append( + 'x-ms-serviceconnector-user-token') return client @@ -109,16 +122,14 @@ def provider_is_registered(subscription=None): subs_arg = '' if subscription: subs_arg = '--subscription {}'.format(subscription) - output = run_cli_cmd('az provider show -n Microsoft.ServiceLinker {}'.format(subs_arg)) + output = run_cli_cmd( + 'az provider show -n Microsoft.ServiceLinker {}'.format(subs_arg)) if output.get('registrationState') == 'NotRegistered': return False return True def register_provider(subscription=None): - from knack.log import get_logger - logger = get_logger(__name__) - logger.warning('Provider Microsoft.ServiceLinker is not registered, ' 'trying to register. This usually takes 1-2 minutes.') @@ -127,7 +138,8 @@ def register_provider(subscription=None): subs_arg = '--subscription {}'.format(subscription) # register the provider - run_cli_cmd('az provider register -n Microsoft.ServiceLinker {}'.format(subs_arg)) + run_cli_cmd( + 'az provider register -n Microsoft.ServiceLinker {}'.format(subs_arg)) # verify the registration, 30 * 10s polling the result MAX_RETRY_TIMES = 30 @@ -136,7 +148,8 @@ def register_provider(subscription=None): count = 0 while count < MAX_RETRY_TIMES: time.sleep(RETRY_INTERVAL) - output = run_cli_cmd('az provider show -n Microsoft.ServiceLinker {}'.format(subs_arg)) + output = run_cli_cmd( + 'az provider show -n Microsoft.ServiceLinker {}'.format(subs_arg)) current_state = output.get('registrationState') if current_state == 'Registered': return True @@ -171,7 +184,8 @@ def auto_register(func, *args, **kwargs): # target subscription is not registered, raw check if ex.error and ex.error.code == 'UnauthorizedResourceAccess' and 'not registered' in ex.error.message: if 'parameters' in kwargs_backup and 'target_id' in kwargs_backup.get('parameters'): - segments = parse_resource_id(kwargs_backup.get('parameters').get('target_id')) + segments = parse_resource_id( + kwargs_backup.get('parameters').get('target_id')) target_subs = segments.get('subscription') # double check whether target subscription is registered if not provider_is_registered(target_subs): @@ -185,8 +199,6 @@ def auto_register(func, *args, **kwargs): def create_key_vault_reference_connection_if_not_exist(cmd, client, source_id, key_vault_id): from ._validators import get_source_resource_name - from knack.log import get_logger - logger = get_logger(__name__) logger.warning('get valid key vualt reference connection') key_vault_connections = [] @@ -196,7 +208,8 @@ def create_key_vault_reference_connection_if_not_exist(cmd, client, source_id, k key_vault_connections.append(connection) source_name = get_source_resource_name(cmd) - auth_info = get_auth_if_no_valid_key_vault_connection(logger, source_name, source_id, key_vault_connections) + auth_info = get_auth_if_no_valid_key_vault_connection( + source_name, source_id, key_vault_connections) if not auth_info: return @@ -204,7 +217,6 @@ def create_key_vault_reference_connection_if_not_exist(cmd, client, source_id, k logger.warning('no valid key vault connection found. Creating...') from ._resource_config import ( - RESOURCE, CLIENT_TYPE ) @@ -215,7 +227,8 @@ def create_key_vault_reference_connection_if_not_exist(cmd, client, source_id, k "id": key_vault_id }, 'auth_info': auth_info, - 'client_type': CLIENT_TYPE.Dotnet, # Key Vault Configuration are same across all client types + # Key Vault Configuration are same across all client types + 'client_type': CLIENT_TYPE.Dotnet, } if source_name == RESOURCE.KubernetesCluster: @@ -230,13 +243,12 @@ def create_key_vault_reference_connection_if_not_exist(cmd, client, source_id, k parameters=parameters) -def get_auth_if_no_valid_key_vault_connection(logger, source_name, source_id, key_vault_connections): +def get_auth_if_no_valid_key_vault_connection(source_name, source_id, key_vault_connections): auth_type = 'systemAssignedIdentity' client_id = None subscription_id = None if key_vault_connections: - from ._resource_config import RESOURCE from msrestazure.tools import ( is_valid_resource_id ) @@ -244,8 +256,10 @@ def get_auth_if_no_valid_key_vault_connection(logger, source_name, source_id, ke # https://docs.microsoft.com/azure/app-service/app-service-key-vault-references if source_name == RESOURCE.WebApp: try: - webapp = run_cli_cmd('az rest -u {}?api-version=2020-09-01 -o json'.format(source_id)) - reference_identity = webapp.get('properties').get('keyVaultReferenceIdentity') + webapp = run_cli_cmd( + 'az rest -u {}?api-version=2020-09-01 -o json'.format(source_id)) + reference_identity = webapp.get( + 'properties').get('keyVaultReferenceIdentity') except Exception as e: raise ValidationError('{}. Unable to get "properties.keyVaultReferenceIdentity" from {}.' 'Please check your source id is correct.'.format(e, source_id)) @@ -255,11 +269,13 @@ def get_auth_if_no_valid_key_vault_connection(logger, source_name, source_id, ke segments = parse_resource_id(reference_identity) subscription_id = segments.get('subscription') try: - identity = webapp.get('identity').get('userAssignedIdentities').get(reference_identity) + identity = webapp.get('identity').get( + 'userAssignedIdentities').get(reference_identity) client_id = identity.get('clientId') except Exception: # pylint: disable=broad-except try: - identity = run_cli_cmd('az identity show --ids {} -o json'.format(reference_identity)) + identity = run_cli_cmd( + 'az identity show --ids {} -o json'.format(reference_identity)) client_id = identity.get('clientId') except Exception: # pylint: disable=broad-except pass @@ -269,12 +285,14 @@ def get_auth_if_no_valid_key_vault_connection(logger, source_name, source_id, ke for connection in key_vault_connections: auth_info = connection.get('authInfo') if auth_info.get('clientId') == client_id and auth_info.get('subscriptionId') == subscription_id: - logger.warning('key vualt reference connection: %s', connection.get('id')) + logger.warning( + 'key vualt reference connection: %s', connection.get('id')) return else: # System Identity for connection in key_vault_connections: if connection.get('authInfo').get('authType') == auth_type: - logger.warning('key vualt reference connection: %s', connection.get('id')) + logger.warning( + 'key vualt reference connection: %s', connection.get('id')) return # any connection with csi enabled is a valid connection @@ -286,7 +304,8 @@ def get_auth_if_no_valid_key_vault_connection(logger, source_name, source_id, ke return {'authType': 'userAssignedIdentity'} else: - logger.warning('key vualt reference connection: %s', key_vault_connections[0].get('id')) + logger.warning('key vualt reference connection: %s', + key_vault_connections[0].get('id')) return auth_info = { diff --git a/src/azure-cli/azure/cli/command_modules/serviceconnector/_validators.py b/src/azure-cli/azure/cli/command_modules/serviceconnector/_validators.py index f9321cce711..0172edb821d 100644 --- a/src/azure-cli/azure/cli/command_modules/serviceconnector/_validators.py +++ b/src/azure-cli/azure/cli/command_modules/serviceconnector/_validators.py @@ -154,10 +154,10 @@ def _infer_webapp(source_id): return client_type def _infer_springcloud(source_id): - client_type = None + client_type = CLIENT_TYPE.SpringBoot try: segments = parse_resource_id(source_id) - output = run_cli_cmd('az spring-cloud app show -g {} -s {} -n {}' + output = run_cli_cmd('az spring app show -g {} -s {} -n {}' ' -o json'.format(segments.get('resource_group'), segments.get('name'), segments.get('child_name_1'))) prop_val = output.get('properties')\ @@ -180,7 +180,7 @@ def _infer_springcloud(source_id): client_type = None if 'webapp' in cmd.name: client_type = _infer_webapp(namespace.source_id) - elif 'spring-cloud' in cmd.name: + elif 'spring-cloud' in cmd.name or 'spring' in cmd.name: client_type = _infer_springcloud(namespace.source_id) method = 'detected' @@ -247,12 +247,12 @@ def interactive_input(arg, hint): def get_local_context_value(cmd, arg): '''Get local context values ''' - groups = ['all', 'cupertino', 'serviceconnector'] + groups = ['all', 'cupertino', 'serviceconnector', 'postgres'] arg_map = { 'source_resource_group': ['resource_group_name'], 'target_resource_group': ['resource_group_name'], - 'server': ['postgres_server_name'], - 'database': ['postgres_database_name'], + 'server': ['server_name', "server"], + 'database': ['database_name', "database"], 'site': ['webapp_name'] } for group in groups: diff --git a/src/azure-cli/azure/cli/command_modules/serviceconnector/action.py b/src/azure-cli/azure/cli/command_modules/serviceconnector/action.py index 5fd1544a2b9..ae5b06c7654 100644 --- a/src/azure-cli/azure/cli/command_modules/serviceconnector/action.py +++ b/src/azure-cli/azure/cli/command_modules/serviceconnector/action.py @@ -118,7 +118,11 @@ def get_action(self, values, option_string): # pylint: disable=no-self-use raise ValidationError('Usage error: {} [KEY=VALUE ...]'.format(option_string)) d = {} for k in properties: - raise ValidationError('Unsupported Key {} is provided for parameter --system-identity') + v = properties[k] + if k.lower() == 'mysql-identity-id': + d['mysql-identity-id'] = v[0] + else: + raise ValidationError('Unsupported Key {} is provided for parameter --system-identity') d['auth_type'] = 'systemAssignedIdentity' return d diff --git a/src/azure-cli/azure/cli/command_modules/serviceconnector/custom.py b/src/azure-cli/azure/cli/command_modules/serviceconnector/custom.py index 86676c64a1f..5396839052c 100644 --- a/src/azure-cli/azure/cli/command_modules/serviceconnector/custom.py +++ b/src/azure-cli/azure/cli/command_modules/serviceconnector/custom.py @@ -30,7 +30,8 @@ set_user_token_header, auto_register ) -# pylint: disable=unused-argument,unsubscriptable-object,unsupported-membership-test +from ._credential_free import enable_mi_for_db_linker +# pylint: disable=unused-argument,unsubscriptable-object,unsupported-membership-test,too-many-statements,too-many-locals logger = get_logger(__name__) @@ -263,6 +264,8 @@ def connection_create(cmd, client, # pylint: disable=too-many-locals,too-many-s 'manually and then create the connection.'.format(str(e))) validate_service_state(parameters) + new_auth_info = enable_mi_for_db_linker(cmd, source_id, target_id, auth_info, client_type, connection_name) + parameters['auth_info'] = new_auth_info if new_auth_info is not None else parameters['auth_info'] return auto_register(sdk_no_wait, no_wait, client.begin_create_or_update, resource_uri=source_id, diff --git a/src/azure-cli/azure/cli/command_modules/serviceconnector/tests/latest/test_passwordless_connection_scenario.py b/src/azure-cli/azure/cli/command_modules/serviceconnector/tests/latest/test_passwordless_connection_scenario.py new file mode 100644 index 00000000000..b06f065f462 --- /dev/null +++ b/src/azure-cli/azure/cli/command_modules/serviceconnector/tests/latest/test_passwordless_connection_scenario.py @@ -0,0 +1,167 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import unittest +from azure.cli.core.commands.client_factory import get_subscription_id +from azure.cli.testsdk import ( + ScenarioTest, + live_only +) +from azure.cli.command_modules.serviceconnector._resource_config import ( + RESOURCE, + SOURCE_RESOURCES, + TARGET_RESOURCES +) +from ._test_utils import CredentialReplacer + +@unittest.skip('Need environment prepared') +class PasswordlessConnectionScenarioTest(ScenarioTest): + + def __init__(self, method_name): + super(PasswordlessConnectionScenarioTest, self).__init__( + method_name, + recording_processors=[CredentialReplacer()] + ) + + def test_aad_webapp_sql(self): + self.kwargs.update({ + 'subscription': get_subscription_id(self.cli_ctx), + 'source_resource_group': 'zxf-test', + 'target_resource_group': 'zxf-test', + 'site': 'xf-mi-test', + 'server': 'servicelinker-sql-mi', + 'database': 'clitest' + }) + name = 'testconn' + source_id = SOURCE_RESOURCES.get(RESOURCE.WebApp).format(**self.kwargs) + target_id = TARGET_RESOURCES.get(RESOURCE.Sql).format(**self.kwargs) + connection_id = source_id + "/providers/Microsoft.ServiceLinker/linkers/" + name + + # prepare + self.cmd('webapp identity remove --ids {}'.format(source_id)) + self.cmd('sql server update -e false --ids {}'.format(target_id)) + self.cmd('sql db create -g {target_resource_group} -s {server} -n {database}') + + # create + self.cmd('webapp connection create sql --connection {} --source-id {} --target-id {} ' + '--system-identity --client-type dotnet'.format(name, source_id, target_id)) + # clean + self.cmd('webapp connection delete --id {} --yes'.format(connection_id)) + + # recreate and test + self.cmd('webapp connection create sql --connection {} --source-id {} --target-id {} ' + '--system-identity --client-type dotnet'.format(name, source_id, target_id)) + # clean + self.cmd('webapp connection delete --id {} --yes'.format(connection_id)) + self.cmd('sql db delete -y -g {target_resource_group} -s {server} -n {database}') + + def test_aad_spring_mysqlflexible(self): + self.kwargs.update({ + 'subscription': get_subscription_id(self.cli_ctx), + 'source_resource_group': 'servicelinker-test-linux-group', + 'target_resource_group': 'zxf-test', + 'spring': 'springeuap', + 'app': 'mysqlflexmi', + 'deployment': 'default', + 'server': 'xf-mysqlflex-test', + 'database': 'mysqlDB', + }) + mysql_identity_id = '/subscriptions/d82d7763-8e12-4f39-a7b6-496a983ec2f4/resourcegroups/zxf-test/providers/Microsoft.ManagedIdentity/userAssignedIdentities/servicelinker-aad-umi' + + # prepare params + name = 'testconn' + source_id = SOURCE_RESOURCES.get(RESOURCE.SpringCloud).format(**self.kwargs) + target_id = TARGET_RESOURCES.get(RESOURCE.MysqlFlexible).format(**self.kwargs) + connection_id = source_id + "/providers/Microsoft.ServiceLinker/linkers/" + name + + # prepare + self.cmd('spring app identity remove -n {app} -s {spring} -g {source_resource_group} --system-assigned') + self.cmd('mysql flexible-server ad-admin delete -g {target_resource_group} -s {server} -y') + self.cmd('mysql flexible-server db create -g {target_resource_group} --server-name {server} --database-name {database}') + # self.cmd('mysql flexible-server identity remove -g {target_resource_group} -s {server} -y --identity ' + mysql_identity_id) + + # create connection + self.cmd('spring connection create mysql-flexible --connection {} --source-id {} --target-id {} ' + '--client-type springboot --system-identity mysql-identity-id={}'.format(name, source_id, target_id, mysql_identity_id)) + # delete connection + self.cmd('spring connection delete --id {} --yes'.format(connection_id)) + + + # create connection + self.cmd('spring connection create mysql-flexible --connection {} --source-id {} --target-id {} ' + '--client-type springboot --system-identity mysql-identity-id={}'.format(name, source_id, target_id, mysql_identity_id)) + # delete connection + self.cmd('spring connection delete --id {} --yes'.format(connection_id)) + self.cmd('mysql flexible-server db delete -y -g {target_resource_group} --server-name {server} --database-name {database}') + + def test_aad_containerapp_postgresflexible(self): + default_container_name = 'simple-hello-world-container' + self.kwargs.update({ + 'subscription': get_subscription_id(self.cli_ctx), + 'source_resource_group': 'zxf-test', + 'target_resource_group': 'zxf-test', + 'app': 'servicelinker-mysql-aca', + 'server': 'xf-pgflex-clitest', + 'database': 'testdb1', + 'containerapp_env': '/subscriptions/d82d7763-8e12-4f39-a7b6-496a983ec2f4/resourceGroups/container-app/providers/Microsoft.App/managedEnvironments/north-europe' + }) + + # prepare params + name = 'testconn' + source_id = SOURCE_RESOURCES.get(RESOURCE.ContainerApp).format(**self.kwargs) + target_id = TARGET_RESOURCES.get(RESOURCE.PostgresFlexible).format(**self.kwargs) + connection_id = source_id + "/providers/Microsoft.ServiceLinker/linkers/" + name + + # prepare + self.cmd('containerapp delete -n {app} -g {source_resource_group}') + self.cmd('containerapp create -n {app} -g {source_resource_group} --environment {containerapp_env} --image nginx') + self.cmd('postgres flexible-server delete -y -g {target_resource_group} -n {server}') + self.cmd('postgres flexible-server create -y -g {target_resource_group} -n {server}') + self.cmd('postgres flexible-server db create -g {target_resource_group} -s {server} -d {database}') + + # create + self.cmd('containerapp connection create postgres-flexible --connection {} --source-id {} --target-id {} ' + '--system-identity --client-type springboot -c {}'.format(name, source_id, target_id, default_container_name)) + configs = self.cmd('containerapp connection list-configuration --id {}'.format(connection_id)).get_output_in_json(); + # clean + self.cmd('containerapp connection delete --id {} --yes'.format(connection_id)) + # + # # recreate and test + # self.cmd('containerapp connection create postgres-flexible --connection {} --source-id {} --target-id {} ' + # '--system-identity --client-type dotnet -c {}'.format(name, source_id, target_id, default_container_name)) + # clean + # self.cmd('containerapp connection delete --id {} --yes'.format(connection_id)) + # self.cmd('postgres flexible-server delete -y -g {target_resource_group} -n {server}') + + + def test_aad_webapp_postgressingle(self): + self.kwargs.update({ + 'subscription': "d82d7763-8e12-4f39-a7b6-496a983ec2f4", + 'source_resource_group': 'zxf-test', + 'target_resource_group': 'zxf-test', + 'site': 'xf-pg-app', + 'server': 'xfpostgre', + 'database': 'testdb' + }) + + # prepare params + name = 'testconn' + source_id = SOURCE_RESOURCES.get(RESOURCE.WebApp).format(**self.kwargs) + target_id = TARGET_RESOURCES.get(RESOURCE.PostgresFlexible).format(**self.kwargs) + connection_id = source_id + "/providers/Microsoft.ServiceLinker/linkers/" + name + + # prepare + self.cmd('webapp identity remove --ids {}'.format(source_id)) + # self.cmd('postgres server delete -y -g {target_resource_group} -n {server}') + # self.cmd('postgres server create -y -g {target_resource_group} -n {server}') + self.cmd('postgres db delete -g {target_resource_group} -s {server} -n {database}') + self.cmd('postgres db create -g {target_resource_group} -s {server} -n {database}') + + # create + self.cmd('webapp connection create postgres-flexible --connection {} --source-id {} --target-id {} ' + '--system-identity --client-type springboot'.format(name, source_id, target_id)) + configs = self.cmd('webapp connection list-configuration --id {}'.format(connection_id)).get_output_in_json(); + diff --git a/src/azure-cli/requirements.py3.Darwin.txt b/src/azure-cli/requirements.py3.Darwin.txt index 39de49164da..5db878c15d0 100644 --- a/src/azure-cli/requirements.py3.Darwin.txt +++ b/src/azure-cli/requirements.py3.Darwin.txt @@ -113,6 +113,7 @@ msal-extensions==1.0.0 msal[broker]==1.20.0b1 msrest==0.7.1 msrestazure==0.6.4 +mycli==1.22.2 oauthlib==3.0.1 packaging==21.3 paramiko==2.10.1 diff --git a/src/azure-cli/requirements.py3.Linux.txt b/src/azure-cli/requirements.py3.Linux.txt index a3849c6b10a..0ac9279f52c 100644 --- a/src/azure-cli/requirements.py3.Linux.txt +++ b/src/azure-cli/requirements.py3.Linux.txt @@ -114,6 +114,7 @@ msal-extensions==1.0.0 msal[broker]==1.20.0b1 msrest==0.7.1 msrestazure==0.6.4 +mycli==1.22.2 oauthlib==3.0.1 packaging==21.3 paramiko==2.10.1 diff --git a/src/azure-cli/requirements.py3.windows.txt b/src/azure-cli/requirements.py3.windows.txt index c7b898e9efd..a352708ff09 100644 --- a/src/azure-cli/requirements.py3.windows.txt +++ b/src/azure-cli/requirements.py3.windows.txt @@ -113,6 +113,7 @@ msal-extensions==1.0.0 msal[broker]==1.20.0b1 msrest==0.7.1 msrestazure==0.6.4 +mycli==1.22.2 oauthlib==3.0.1 packaging==21.3 paramiko==2.10.1 @@ -120,6 +121,7 @@ pbr==5.3.1 pkginfo==1.8.2 portalocker==2.3.2 psutil==5.9.0 +psycopg2==2.9.3 pycparser==2.19 PyGithub==1.55 PyJWT==2.4.0 diff --git a/src/azure-cli/setup.py b/src/azure-cli/setup.py index bae49794673..623cdaed6d2 100644 --- a/src/azure-cli/setup.py +++ b/src/azure-cli/setup.py @@ -142,6 +142,7 @@ 'fabric~=2.4', 'javaproperties~=0.5.1', 'jsondiff~=2.0.0', + 'mycli~=1.22.2', 'packaging>=20.9,<22.0', 'PyGithub~=1.38', 'PyNaCl~=1.5.0',