Skip to content

Commit

Permalink
{Service Connector} Add input validation for run_cli_cmd (#29798)
Browse files Browse the repository at this point in the history
* fix run_cli_cmd

* add input validator

* fix no attribute error

* fix add_target_resource_block

* fix no attribute

* update confluent

* lint

* update is_valid_resource_id

* update

* update validate_source_id

* lint

* remove msrestazure
  • Loading branch information
xfz11 authored Sep 4, 2024
1 parent 0f4aefd commit 6213de3
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@

import re
from knack.log import get_logger
from msrestazure.tools import (
from azure.mgmt.core.tools import (
parse_resource_id,
is_valid_resource_id
)
from azure.cli.core import telemetry
from azure.cli.core.commands.client_factory import get_subscription_id
Expand All @@ -17,7 +16,8 @@
)
from ._utils import (
generate_random_string,
run_cli_cmd
run_cli_cmd,
is_valid_resource_id
)
from ._resource_config import (
RESOURCE,
Expand All @@ -42,12 +42,13 @@
AddonConfig = {
RESOURCE.Postgres: {
'create': [
'az postgres server create -g {target_resource_group} -n {server} -l {location} -u {user} -p {password}',
'az postgres db create -g {target_resource_group} -s {server} -n {database}'
'az postgres server create -g "{target_resource_group}" -n "{server}" -l "{location}" -u "{user}" \
-p "{password}"',
'az postgres db create -g "{target_resource_group}" -s "{server}" -n {database}'
],
'delete': [
'az postgres server delete -g {target_resource_group} -n {server} --yes',
'az postgres db delete -g {target_resource_group} -s {server} -n {database} --yes'
'az postgres server delete -g "{target_resource_group}" -n "{server}" --yes',
'az postgres db delete -g "{target_resource_group}" -s "{server}" -n "{database}" --yes'
],
'params': {
'target_resource_group': '_retrive_source_rg',
Expand All @@ -59,17 +60,17 @@
}
},
RESOURCE.KeyVault: {
'create': ['az keyvault create -g {target_resource_group} -n {vault} -l {location}'],
'delete': ['az keyvault delete -g {target_resource_group} -n {vault} --yes'],
'create': ['az keyvault create -g "{target_resource_group}" -n "{vault}" -l "{location}"'],
'delete': ['az keyvault delete -g "{target_resource_group}" -n "{vault}" --yes'],
'params': {
'target_resource_group': '_retrive_source_rg',
'location': '_retrive_source_loc',
'vault': generate_random_string(length=5, prefix='vault-')
}
},
RESOURCE.StorageBlob: {
'create': ['az storage account create -g {target_resource_group} -n {account} -l {location}'],
'delete': ['az storage account delete -g {target_resource_group} -n {account} --yes'],
'create': ['az storage account create -g "{target_resource_group}" -n "{account}" -l "{location}"'],
'delete': ['az storage account delete -g "{target_resource_group}" -n "{account}" --yes'],
'params': {
'target_resource_group': '_retrive_source_rg',
'location': '_retrive_source_loc',
Expand Down Expand Up @@ -186,7 +187,7 @@ def _retrive_source_loc(self):
'''Retrieve the location of source resource group
'''
rg = self._retrive_source_rg()
output = run_cli_cmd('az group show -n {} -o json'.format(rg))
output = run_cli_cmd('az group show -n "{}" -o json'.format(rg))
return output.get('location')

def _get_source_type(self):
Expand Down
21 changes: 11 additions & 10 deletions src/azure-cli/azure/cli/command_modules/serviceconnector/_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from .action import AddCustomizedKeys


def add_source_resource_block(context, source, enable_id=True, validate_source_id=False):
def add_source_resource_block(context, source, enable_id=True):
source_args = SOURCE_RESOURCES_PARAMS.get(source)
for resource, args in SOURCE_RESOURCES_PARAMS.items():
if resource != source:
Expand All @@ -57,7 +57,7 @@ def add_source_resource_block(context, source, enable_id=True, validate_source_i
required_args.append(content.get('options')[0])

validator_kwargs = {
'validator': validate_params} if validate_source_id else {}
'validator': validate_params}
if not enable_id:
context.argument('source_id', options_list=['--source-id'], type=str,
help="The resource id of a {source}. Required if {required_args} "
Expand Down Expand Up @@ -140,14 +140,15 @@ def add_target_resource_block(context, target):
context.ignore(arg)

required_args = []
for arg, content in TARGET_RESOURCES_PARAMS.get(target).items():
context.argument(arg, options_list=content.get('options'), type=str,
help='{}. Required if \'--target-id\' is not specified.'.format(content.get('help')))
required_args.append(content.get('options')[0])
if target in TARGET_RESOURCES_PARAMS:
for arg, content in TARGET_RESOURCES_PARAMS.get(target).items():
context.argument(arg, options_list=content.get('options'), type=str,
help='{}. Required if \'--target-id\' is not specified.'.format(content.get('help')))
required_args.append(content.get('options')[0])

context.argument('target_id', type=str,
help='The resource id of target service. Required if {required_args} '
'are not specified.'.format(required_args=str(required_args)))
context.argument('target_id', type=str,
help='The resource id of target service. Required if {required_args} '
'are not specified.'.format(required_args=str(required_args)))

if target != RESOURCE.KeyVault:
context.ignore('enable_csi')
Expand Down Expand Up @@ -262,7 +263,7 @@ def load_arguments(self, _): # pylint: disable=too-many-statements

with self.argument_context('{} connection list'.format(source.value)) as c:
add_source_resource_block(
c, source, enable_id=False, validate_source_id=True)
c, source, enable_id=False)

with self.argument_context('{} connection show'.format(source.value)) as c:
add_source_resource_block(c, source)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def transform_linker_properties(result):
if is_aks_linker_by_id(resource_id):
result['kubernetesResourceName'] = get_aks_resource_name(result)
try:
output = run_cli_cmd('az webapp connection list-configuration --id {} -o json'.format(resource_id))
output = run_cli_cmd('az webapp connection list-configuration --id "{}" -o json'.format(resource_id))
result['configurations'] = output.get('configurations')
except CLIInternalError:
pass
Expand All @@ -77,7 +77,7 @@ def transform_local_linker_properties(result):
result = todict(result)
resource_id = result.get('id')
try:
output = run_cli_cmd('az connection generate-configuration --id {} -o json'.format(resource_id))
output = run_cli_cmd('az connection generate-configuration --id "{}" -o json'.format(resource_id))
result['configurations'] = output.get('configurations')
except CLIInternalError:
pass
Expand Down
25 changes: 15 additions & 10 deletions src/azure-cli/azure/cli/command_modules/serviceconnector/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import time
from knack.log import get_logger
from knack.util import todict, CLIError
from msrestazure.tools import parse_resource_id
from azure.cli.core.azclierror import (
ValidationError,
CLIInternalError
Expand All @@ -18,11 +17,20 @@
TARGET_RESOURCES_USERTOKEN,
RESOURCE
)

from azure.mgmt.core.tools import (
parse_resource_id,
is_valid_resource_id as is_valid_resource_id_sdk
)

logger = get_logger(__name__)


def is_valid_resource_id(value):
if re.search('[\"\'|]', value):
return False
return is_valid_resource_id_sdk(value)


def should_load_source(source):
'''Check whether to load `az {source} connection`
If {source} is an extension (e.g, spring-cloud), load the command group only when {source} is installed
Expand Down Expand Up @@ -124,7 +132,7 @@ def provider_is_registered(subscription=None):
# register the provider
subs_arg = ''
if subscription:
subs_arg = '--subscription {}'.format(subscription)
subs_arg = '--subscription "{}"'.format(subscription)
output = run_cli_cmd(
'az provider show -n Microsoft.ServiceLinker {}'.format(subs_arg))
if output.get('registrationState') == 'NotRegistered':
Expand All @@ -138,7 +146,7 @@ def register_provider(subscription=None):

subs_arg = ''
if subscription:
subs_arg = '--subscription {}'.format(subscription)
subs_arg = '--subscription "{}"'.format(subscription)

# register the provider
run_cli_cmd(
Expand Down Expand Up @@ -275,13 +283,10 @@ def get_auth_if_no_valid_key_vault_connection(source_name, source_id, key_vault_

# https://docs.microsoft.com/azure/app-service/app-service-key-vault-references
def get_auth_if_no_valid_key_vault_connection_for_webapp(source_id, key_vault_connections):
from msrestazure.tools import (
is_valid_resource_id
)

try:
webapp = run_cli_cmd(
'az rest -u {}?api-version=2020-09-01 -o json'.format(source_id))
'az rest -u "{}?api-version=2020-09-01" -o json'.format(source_id))
reference_identity = webapp.get(
'properties').get('keyVaultReferenceIdentity')
except Exception as e:
Expand All @@ -299,7 +304,7 @@ def get_auth_if_no_valid_key_vault_connection_for_webapp(source_id, key_vault_co
except Exception: # pylint: disable=broad-except
try:
identity = run_cli_cmd(
'az identity show --ids {} -o json'.format(reference_identity))
'az identity show --ids "{}" -o json'.format(reference_identity))
client_id = identity.get('clientId')
except Exception: # pylint: disable=broad-except
pass
Expand Down Expand Up @@ -405,7 +410,7 @@ def get_object_id_of_current_user():
return user_object_id
if user_type == 'servicePrincipal':
user_info = run_cli_cmd(
f'az ad sp show --id {signed_in_user.get("name")} -o json')
f'az ad sp show --id "{signed_in_user.get("name")}" -o json')
user_object_id = user_info.get('id')
return user_object_id
except CLIInternalError as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
prompt,
prompt_pass
)
from msrestazure.tools import (
from azure.mgmt.core.tools import (
parse_resource_id,
is_valid_resource_id
)
from azure.cli.core import telemetry
from azure.cli.core.commands.client_factory import get_subscription_id
Expand All @@ -26,7 +25,8 @@

from ._utils import (
run_cli_cmd,
get_object_id_of_current_user
get_object_id_of_current_user,
is_valid_resource_id
)
from ._resource_config import (
AUTH_TYPE,
Expand Down Expand Up @@ -148,7 +148,7 @@ def _infer_webapp(source_id):

client_type = None
try:
output = run_cli_cmd('az webapp show --id {} -o json'.format(source_id))
output = run_cli_cmd('az webapp show --id "{}" -o json'.format(source_id))
prop = output.get('siteConfig').get('linuxFxVersion', None) or\
output.get('siteConfig').get('windowsFxVersion', None)
# use 'linuxFxVersion' and 'windowsFxVersion' property to decide
Expand All @@ -170,7 +170,7 @@ def _infer_springcloud(source_id):
client_type = CLIENT_TYPE.SpringBoot
try:
segments = parse_resource_id(source_id)
output = run_cli_cmd('az spring app show -g {} -s {} -n {}'
output = run_cli_cmd('az spring app show -g "{}" -s "{}" -n "{}"'
' -o json'.format(segments.get('resource_group'), segments.get('name'),
segments.get('child_name_1')))
prop_val = output.get('properties')\
Expand Down Expand Up @@ -491,8 +491,9 @@ def get_missing_target_args(cmd):
target = get_target_resource_name(cmd)
missing_args = dict()

for arg, content in TARGET_RESOURCES_PARAMS.get(target).items():
missing_args[arg] = content
if target in TARGET_RESOURCES_PARAMS:
for arg, content in TARGET_RESOURCES_PARAMS.get(target).items():
missing_args[arg] = content

return missing_args

Expand All @@ -512,6 +513,8 @@ def get_missing_auth_args(cmd, namespace):
auth_param_exist = True
break

if target == RESOURCE.ConfluentKafka:
return missing_args
# when keyvault csi is enabled, auth_type is userIdentity without subs_id and client_id
if source == RESOURCE.KubernetesCluster and target == RESOURCE.KeyVault:
if getattr(namespace, 'enable_csi', None):
Expand Down Expand Up @@ -662,7 +665,7 @@ def validate_update_params(cmd, namespace):
'''Get missing args of update command
'''
missing_args = dict()
if not validate_connection_id(namespace):
if not validate_connection_id(namespace) and not validate_source_resource_id(cmd, namespace):
missing_args.update(get_missing_source_args(cmd, namespace))
# missing_args.update(get_missing_auth_args(cmd, namespace))
missing_args.update(get_missing_connection_name(namespace))
Expand Down Expand Up @@ -770,7 +773,7 @@ def apply_auth_args(cmd, namespace, arg_values):


def apply_workload_identity(namespace, arg_values):
output = run_cli_cmd('az identity show --ids {}'.format(
output = run_cli_cmd('az identity show --ids "{}"'.format(
arg_values.get('workload_identity_auth_info')
))
if output:
Expand Down Expand Up @@ -905,7 +908,7 @@ def _validate_and_apply(validate, apply):
namespace.connection_name = generate_connection_name(cmd)
else:
validate_connection_name(namespace.connection_name)
if getattr(namespace, 'new_addon'):
if getattr(namespace, 'new_addon', None):
_validate_and_apply(validate_addon_params, apply_addon_params)
else:
_validate_and_apply(validate_create_params, apply_create_params)
Expand Down Expand Up @@ -957,7 +960,7 @@ def validate_service_state(linker_parameters):
if not rg or not name:
return

output = run_cli_cmd('az appconfig show -g {} -n {}'.format(rg, name))
output = run_cli_cmd('az appconfig show -g "{}" -n "{}"'.format(rg, name))
if output and output.get('disableLocalAuth') is True:
raise ValidationError('Secret as auth type is not allowed when local auth is disabled for the '
'specified appconfig, you may use service principal or managed identity.')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def get_action(self, values, option_string, command_name): # pylint: disable=no
'Required keys are: client-id, secret')
if 'principal_id' not in d:
from ._utils import run_cli_cmd
output = run_cli_cmd('az ad sp show --id {}'.format(d['client_id']))
output = run_cli_cmd('az ad sp show --id "{}"'.format(d['client_id']))
if output:
d['principal_id'] = output.get('id')
else:
Expand Down Expand Up @@ -258,7 +258,7 @@ def get_action(self, values, option_string): # pylint: disable=no-self-use
d = {}
if 'user-identity-resource-id' in properties:
from ._utils import run_cli_cmd
output = run_cli_cmd('az identity show --ids {}'.format(properties['user-identity-resource-id']))
output = run_cli_cmd('az identity show --ids "{}"'.format(properties['user-identity-resource-id']))
if output:
d['client_id'] = output.get('clientId')
d['subscription_id'] = properties['user-identity-resource-id'].split('/')[2]
Expand Down

0 comments on commit 6213de3

Please sign in to comment.