From 9caf7a9b80d3aa5b0aea16d9b40d98fbbf796e1e Mon Sep 17 00:00:00 2001 From: Tejaswi Kandula Date: Wed, 21 Aug 2024 09:38:55 -0700 Subject: [PATCH] Remove fetching resourcehostname on client side --- src/bastion/azext_bastion/custom.py | 56 ++++++++--------------------- 1 file changed, 14 insertions(+), 42 deletions(-) diff --git a/src/bastion/azext_bastion/custom.py b/src/bastion/azext_bastion/custom.py index 48ca6076fa4..12c0327ec5b 100644 --- a/src/bastion/azext_bastion/custom.py +++ b/src/bastion/azext_bastion/custom.py @@ -25,10 +25,6 @@ from msrestazure.tools import is_valid_resource_id from .BastionServiceConstants import BastionSku from .aaz.latest.network.bastion import Create as _BastionCreate -from azure.identity import AzureCliCredential -from azure.mgmt.resourcegraph import ResourceGraphClient -from azure.mgmt.compute import ComputeManagementClient -from azure.mgmt.resourcegraph.models import QueryRequest logger = get_logger(__name__) @@ -235,36 +231,20 @@ def _get_rdp_path(rdp_command="mstsc"): return rdp_path -def get_host_name(cmd, target_resource_id, resource_group_name): - credential = AzureCliCredential() - subscription_id = get_subscription_id(cmd.cli_ctx) - resource_graph_client = ResourceGraphClient(credential) - - # Query to get the VM resource - query = f""" - Resources - | where type == 'microsoft.compute/virtualmachines' - | where tolower(id) == tolower('{target_resource_id}') - | project id, name, resourceGroup - """ - query_request = QueryRequest( - subscriptions=[subscription_id], - query=query +def _generate_rdp_file(port): + import os + + rdp_file_content = ( + f"full address:s:localhost:{port}\n" + f"alternate full address:s:localhost:{port}\n" + "use multimon:i:1\n" ) - query_response = resource_graph_client.resources(query_request) - vm_name = None - hostname = None - for result in query_response.data: - if result['id'].lower() == target_resource_id.lower() and result['resourceGroup'].lower() == resource_group_name.lower(): - vm_name = result['name'] - break - if vm_name: - compute_client = ComputeManagementClient(credential, subscription_id) - vm_instance = compute_client.virtual_machines.get(resource_group_name, vm_name) - hostname = vm_instance.os_profile.computer_name + rdpfilepath = os.path.join(tempfile.gettempdir(), f'conn_{uuid.uuid4().hex}.rdp') + with open(rdpfilepath, 'w') as rdp_file: + rdp_file.write(rdp_file_content) - return hostname + return rdpfilepath def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_name, bastion_host_name, @@ -323,15 +303,7 @@ def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_ t.daemon = True t.start() - rdp_file_content = ( - f"full address:s:localhost:{tunnel_server.local_port}\n" - f"alternate full address:s:localhost:{tunnel_server.local_port}\n" - "use multimon:i:1\n" - ) - - rdpfilepath = os.path.join(tempfile.gettempdir(), f'conn_{uuid.uuid4().hex}.rdp') - with open(rdpfilepath, 'w') as rdp_file: - rdp_file.write(rdp_file_content) + rdpfilepath = _generate_rdp_file(tunnel_server.local_port) command = [_get_rdp_path()] if configure: @@ -340,10 +312,9 @@ def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_ launch_and_wait(command) tunnel_server.cleanup() else: - hostname = get_host_name(cmd, target_resource_id, resource_group_name) access_token = Profile(cli_ctx=cmd.cli_ctx).get_raw_token()[0][2].get("accessToken") logger.debug("Response %s", access_token) - web_address = f"https://{bastion_endpoint}/api/rdpfile?resourceId={target_resource_id}&resourceHostName={hostname}" \ + web_address = f"https://{bastion_endpoint}/api/rdpfile?resourceId={target_resource_id}" \ f"&format=rdp&rdpport={resource_port}&enablerdsaad={enable_mfa}" headers = { @@ -370,6 +341,7 @@ def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_ else: raise UnrecognizedArgumentError("Platform is not supported for this command. Supported platforms: Windows") + def _is_ipconnect_request(bastion, target_ip_address): if target_ip_address: if 'enableIpConnect' in bastion and bastion['enableIpConnect'] is True: