diff --git a/dags/inference/configs/trt_llm_inference_config.py b/dags/inference/configs/trt_llm_inference_config.py
index c08a9132..8bdfeb67 100644
--- a/dags/inference/configs/trt_llm_inference_config.py
+++ b/dags/inference/configs/trt_llm_inference_config.py
@@ -36,6 +36,7 @@ def get_trt_llm_gpu_config(
     project: Project,
     network: str,
     subnetwork: str,
+    existing_instance_name: str = None,
 ) -> task.GpuCreateResourceTask:
   set_up_cmds = (
       "pip install --upgrade pip",
@@ -141,6 +142,7 @@ def get_trt_llm_gpu_config(
       timeout=datetime.timedelta(minutes=time_out_in_min),
       task_owner=test_owner.YIJIA_J,
       gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/trt_llm",
+      use_existing_instance=existing_instance_name is not None,
   )

   job_gcp_config = gcp_config.GCPConfig(
@@ -160,4 +162,5 @@ def get_trt_llm_gpu_config(
       job_test_config,
       job_gcp_config,
       job_metric_config,
+      existing_instance_name=existing_instance_name,
   )
diff --git a/dags/inference/configs/trt_llm_mlperf_v40_config.py b/dags/inference/configs/trt_llm_mlperf_v40_config.py
index d50b000a..3844c854 100644
--- a/dags/inference/configs/trt_llm_mlperf_v40_config.py
+++ b/dags/inference/configs/trt_llm_mlperf_v40_config.py
@@ -37,6 +37,7 @@ def get_trt_llm_mlperf_v40_gpu_config(
     project: Project,
     network: str,
     subnetwork: str,
+    existing_instance_name: str = None,
     model_configs: Dict = {},
 ) -> task.GpuCreateResourceTask:
   docker_container_name = "mlperf-inference"
@@ -109,6 +110,7 @@ def get_trt_llm_mlperf_v40_gpu_config(
   docker_cmd = " && ".join(docker_cmds)
   run_model_cmds = (
       "pip install jsonlines",
+      f"docker restart {docker_container_name}",
       f'docker exec -i {docker_container_name} /bin/bash -c "{docker_cmd}"',
       make_jsonl_converter_cmd,
       "cat jsonl_converter.py",
@@ -133,6 +135,7 @@ def get_trt_llm_mlperf_v40_gpu_config(
       timeout=datetime.timedelta(minutes=time_out_in_min),
       task_owner=test_owner.YIJIA_J,
       gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/trt_llm_mlperf_v40",
+      use_existing_instance=existing_instance_name is not None,
   )

   job_gcp_config = gcp_config.GCPConfig(
@@ -152,4 +155,5 @@ def get_trt_llm_mlperf_v40_gpu_config(
       job_test_config,
       job_gcp_config,
       job_metric_config,
+      existing_instance_name=existing_instance_name,
   )
diff --git a/dags/inference/configs/trt_llm_mlperf_v41_config.py b/dags/inference/configs/trt_llm_mlperf_v41_config.py
index 7371739a..e75dccfa 100644
--- a/dags/inference/configs/trt_llm_mlperf_v41_config.py
+++ b/dags/inference/configs/trt_llm_mlperf_v41_config.py
@@ -37,6 +37,7 @@ def get_trt_llm_mlperf_gpu_config(
     project: Project,
     network: str,
     subnetwork: str,
+    existing_instance_name: str = None,
     benchmark_configs: Dict = {},
     model_parameters: Dict = {},
     parameter_positions: Dict = {},
@@ -196,6 +197,7 @@ def get_trt_llm_mlperf_gpu_config(
       timeout=datetime.timedelta(minutes=time_out_in_min),
       task_owner=test_owner.YIJIA_J,
       gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/trt_llm_mlperf_v41",
+      use_existing_instance=existing_instance_name is not None,
   )

   job_gcp_config = gcp_config.GCPConfig(
@@ -215,4 +217,5 @@ def get_trt_llm_mlperf_gpu_config(
       job_test_config,
       job_gcp_config,
       job_metric_config,
+      existing_instance_name=existing_instance_name,
   )
diff --git a/dags/inference/trt_llm_mlperf_v40_inference.py b/dags/inference/trt_llm_mlperf_v40_inference.py
index 53db12a9..92f05610 100644
--- a/dags/inference/trt_llm_mlperf_v40_inference.py
+++ b/dags/inference/trt_llm_mlperf_v40_inference.py
@@ -61,4 +61,5 @@
       network=INFERENCE_NETWORKS,
       subnetwork=H100_INFERENCE_SUBNETWORKS,
       model_configs=model_configs,
existing_instance_name="yijiaj-test-h100x8", ).run() diff --git a/dags/pytorch_xla/configs/pytorchxla_torchbench_config.py b/dags/pytorch_xla/configs/pytorchxla_torchbench_config.py index 24b29b4b..00824db5 100644 --- a/dags/pytorch_xla/configs/pytorchxla_torchbench_config.py +++ b/dags/pytorch_xla/configs/pytorchxla_torchbench_config.py @@ -475,6 +475,7 @@ def get_torchbench_gpu_config( timeout=datetime.timedelta(minutes=time_out_in_min), task_owner=test_owner.PEI_Z, gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/torchbench", + use_existing_instance=False, ) job_metric_config = metric_config.MetricConfig( diff --git a/dags/solutions_team/configs/vllm/vllm_benchmark_config.py b/dags/solutions_team/configs/vllm/vllm_benchmark_config.py index 5623428f..4d1d46f9 100644 --- a/dags/solutions_team/configs/vllm/vllm_benchmark_config.py +++ b/dags/solutions_team/configs/vllm/vllm_benchmark_config.py @@ -186,6 +186,7 @@ def get_gpu_vllm_gce_config( timeout=datetime.timedelta(minutes=time_out_in_min), task_owner=test_owner.RICHARD_L, gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/vllm_benchmark", + use_existing_instance=False, ) job_gcp_config = gcp_config.GCPConfig( diff --git a/xlml/apis/task.py b/xlml/apis/task.py index 5e721cee..40f0443b 100644 --- a/xlml/apis/task.py +++ b/xlml/apis/task.py @@ -349,6 +349,7 @@ class GpuCreateResourceTask(BaseTask): task_metric_config: metric configuration (e.g., result gcs path). gpu_create_timeout: timeout when waiting for the GPU vm creation. install_nvidia_drivers: whether to install Nvidia drivers. + existing_instance_name: whether an existing GPU instance shall be used. """ image_project: str @@ -358,6 +359,7 @@ class GpuCreateResourceTask(BaseTask): task_metric_config: Optional[metric_config.MetricConfig] = None gpu_create_timeout: datetime.timedelta = datetime.timedelta(minutes=60) install_nvidia_drivers: bool = False + existing_instance_name: str = None def run(self) -> DAGNode: """Run a test job. @@ -368,6 +370,9 @@ def run(self) -> DAGNode: """ # piz: We skip the queued resource for GPU for now since there is no queued # resource command for GPU. + if self.existing_instance_name is not None: + return self.run_with_existing_instance() + with TaskGroup( group_id=self.task_test_config.benchmark_id, prefix_group_id=True ) as group: @@ -399,6 +404,58 @@ def run(self) -> DAGNode: provision >> run_model >> post_process >> clean_up return group + def run_with_existing_instance(self) -> DAGNode: + """Run a test job via existing instance. + + Returns: + A task group with the following tasks chained: provision, run_model and post_process, clean_up. + """ + with TaskGroup( + group_id=self.task_test_config.benchmark_id, prefix_group_id=True + ) as group: + ( + provision, + ip_address, + ssh_keys, + gcs_location, + ) = self.provision_via_existing_instance() + if ( + self.task_metric_config + and self.task_metric_config.use_runtime_generated_gcs_folder + ): + env_variable = { + f"{metric_config.SshEnvVars.GCS_OUTPUT.name}": gcs_location + } + else: + env_variable = None + post_process = self.post_process(gcs_location) + run_model = self.run_model(ip_address, ssh_keys, env_variable) + clean_up = self.clean_up_existing_instance(ssh_keys) + provision >> run_model >> post_process >> clean_up + return group + + def provision_via_existing_instance( + self, + ) -> Tuple[DAGNode, airflow.XComArg, airflow.XComArg, airflow.XComArg,]: + """Provision an existing GPU accelerator. 
+
+    Returns:
+      A DAG node that provisions the existing GPU, plus XCom values for the
+      host IP address, the SSH keys, and the GCS output location.
+    """
+    with TaskGroup(group_id="provision") as group:
+      ssh_keys = ssh.generate_ssh_keys()
+      ip_address = gpu.get_existing_resource(
+          instance_name=self.existing_instance_name,
+          ssh_keys=ssh_keys,
+          gcp=self.task_gcp_config,
+      )
+      gcs_location = name_format.generate_gcs_folder_location(
+          self.task_test_config.gcs_subfolder,
+          self.task_test_config.benchmark_id,
+      )
+    return group, ip_address, ssh_keys, gcs_location
+
   def provision(
       self,
   ) -> Tuple[
@@ -476,7 +533,8 @@ def run_model(
     )

   def post_process(
-      self, result_location: Optional[airflow.XComArg] = None
+      self,
+      result_location: Optional[airflow.XComArg] = None,
   ) -> DAGNode:
     """Process metrics and metadata, and insert them into BigQuery tables.

@@ -513,6 +571,20 @@ def clean_up(
         resource, project_id, zone
     )

+  def clean_up_existing_instance(self, ssh_keys: airflow.XComArg) -> DAGNode:
+    """Clean up an existing GPU instance by removing the generated one-time SSH keys.
+
+    Args:
+      ssh_keys: the generated one-time SSH keys to be removed.
+    Returns:
+      A DAG node that cleans up the SSH keys.
+    """
+    return gpu.clean_up_ssh_keys(
+        instance_name=self.existing_instance_name,
+        ssh_keys=ssh_keys,
+        gcp=self.task_gcp_config,
+    )
+

 # TODO(ranran): This class is big. Let's move it to a new file.
 @dataclasses.dataclass
diff --git a/xlml/apis/test_config.py b/xlml/apis/test_config.py
index e6d724cc..1704b95a 100644
--- a/xlml/apis/test_config.py
+++ b/xlml/apis/test_config.py
@@ -224,11 +224,13 @@ class GpuVmTest(TestConfig[Gpu]):
     test_name: Unique name for this test/model.
     set_up_cmds: List of commands to run once when GPU is created.
     run_model_cmds: List of commands to run the model under test.
+    use_existing_instance: Whether to use an existing GPU instance.
   """

   test_name: str
   set_up_cmds: Iterable[str]
   run_model_cmds: Iterable[str]
+  use_existing_instance: bool

   @property
   def benchmark_id(self) -> str:
diff --git a/xlml/utils/gpu.py b/xlml/utils/gpu.py
index 877c6059..e19a82d7 100644
--- a/xlml/utils/gpu.py
+++ b/xlml/utils/gpu.py
@@ -129,6 +129,128 @@ def generate_gpu_name() -> str:
   return f"gpu-{str(uuid.uuid4())}"


+@task
+def get_existing_resource(
+    instance_name: str,
+    ssh_keys: ssh.SshKeys,
+    gcp: gcp_config.GCPConfig,
+) -> airflow.XComArg:
+  """Retrieve an already-created GPU instance and register one-time SSH keys.
+
+  Args:
+    instance_name: name of the existing instance.
+    ssh_keys: the generated one-time SSH keys to add to the instance metadata.
+    gcp: GCP project/zone configuration.
+
+  Returns:
+    The IP address of the GPU VM.
+ """ + instance_client = compute_v1.InstancesClient() + instance_request = compute_v1.GetInstanceRequest( + instance=instance_name, + project=gcp.project_name, + zone=gcp.zone, + ) + instance = instance_client.get(request=instance_request) + logging.info( + f"Resource retrieve status: {instance.status}, {instance.status_message}" + ) + + ip_address = instance.network_interfaces[0].network_i_p + metadata = instance.metadata + items = metadata.items or [] + ssh_key_exist = False + for item in metadata.items: + if item.key == "ssh-keys": + ssh_key_exist = True + item.value = ( + item.value + "\n" + f"cloud-ml-auto-solutions:{ssh_keys.public}" + ) + break + if not ssh_key_exist: + items.append({ + "key": "ssh-keys", + "value": f"cloud-ml-auto-solutions:{ssh_keys.public}", + }) + metadata.items = items + metadata_request = compute_v1.SetMetadataInstanceRequest( + instance=instance_name, + project=gcp.project_name, + zone=gcp.zone, + metadata_resource=metadata, + ) + operation = instance_client.set_metadata(request=metadata_request) + if operation.error: + logging.error( + ( + "Error during instance set metadata: [Code:" + f" {operation.http_error_status_code}]:" + f" {operation.http_error_message}" + f" {operation.error}" + ), + ) + raise operation.exception() or RuntimeError(operation.http_error_message) + elif operation.warnings: + logging.warning("Warnings during instance set metadata:\n") + for warning in operation.warnings: + logging.warning(f" - {warning.code}: {warning.message}") + + return ip_address + + +@task(trigger_rule="all_done") +def clean_up_ssh_keys( + instance_name: str, + ssh_keys: ssh.SshKeys, + gcp: gcp_config.GCPConfig, +) -> airflow.XComArg: + """Remove the generated one-time use ssh_keys from existing instance. + + Args: + instance_name: name of the existing instance. + ssh_keys: airflow.XComArg, + gcp: GCP project/zone configuration. 
+ """ + instance_client = compute_v1.InstancesClient() + instance_request = compute_v1.GetInstanceRequest( + instance=instance_name, + project=gcp.project_name, + zone=gcp.zone, + ) + instance = instance_client.get(request=instance_request) + logging.info( + f"Resource get status: {instance.status}, {instance.status_message}" + ) + metadata = instance.metadata + for item in metadata.items: + if item.key == "ssh-keys": + item.value = item.value.replace( + f"\ncloud-ml-auto-solutions:{ssh_keys.public}", "" + ) + break + metadata_request = compute_v1.SetMetadataInstanceRequest( + instance=instance_name, + project=gcp.project_name, + zone=gcp.zone, + metadata_resource=metadata, + ) + operation = instance_client.set_metadata(request=metadata_request) + if operation.error: + logging.error( + ( + "Error during instance set metadata: [Code:" + f" {operation.http_error_status_code}]:" + f" {operation.http_error_message}" + f" {operation.error}" + ), + ) + raise operation.exception() or RuntimeError(operation.http_error_message) + elif operation.warnings: + logging.warning("Warnings during instance set metadata:\n") + for warning in operation.warnings: + logging.warning(f" - {warning.code}: {warning.message}") + + @task_group def create_resource( gpu_name: airflow.XComArg, diff --git a/xlml/utils/metric.py b/xlml/utils/metric.py index 06879511..26010953 100644 --- a/xlml/utils/metric.py +++ b/xlml/utils/metric.py @@ -614,9 +614,14 @@ def get_gce_job_status( task_id=f"{benchmark_id}.provision.create_queued_resource.wait_for_ready_queued_resource" ) elif isinstance(task_test_config, test_config.GpuVmTest): - wait_task = current_dag.get_task( - task_id=f"{benchmark_id}.provision.create_resource.get_ip_address" - ) + if task_test_config.use_existing_instance: + wait_task = current_dag.get_task( + task_id=f"{benchmark_id}.provision.get_existing_resource" + ) + else: + wait_task = current_dag.get_task( + task_id=f"{benchmark_id}.provision.create_resource.get_ip_address" + ) else: raise NotImplementedError( f"Unable to get task for {type(task_test_config.accelerator)}." @@ -631,12 +636,29 @@ def get_gce_job_status( return bigquery.JobStatus.MISSED # check setup status to see if setup step is successful - setup_task = current_dag.get_task(task_id=f"{benchmark_id}.provision.setup") - setup_ti = TaskInstance(setup_task, execution_date) - setup_state = setup_ti.current_state() - if setup_state == TaskState.FAILED.value: - logging.info("The setup state is failed, and the job status is failed.") - return bigquery.JobStatus.FAILED + if ( + hasattr(task_test_config, "use_existing_instance") + and task_test_config.use_existing_instance + ): + get_instance_task = current_dag.get_task( + task_id=f"{benchmark_id}.provision.get_existing_resource" + ) + get_instance_ti = TaskInstance(get_instance_task, execution_date) + get_instance_state = get_instance_ti.current_state() + if get_instance_state == TaskState.FAILED.value: + logging.info( + "The getting existing instance state is failed, and the job status is failed." 
+      )
+      return bigquery.JobStatus.FAILED
+  else:
+    setup_task = current_dag.get_task(
+        task_id=f"{benchmark_id}.provision.setup"
+    )
+    setup_ti = TaskInstance(setup_task, execution_date)
+    setup_state = setup_ti.current_state()
+    if setup_state == TaskState.FAILED.value:
+      logging.info("The setup state is failed, and the job status is failed.")
+      return bigquery.JobStatus.FAILED

   # check run_model status to see if run_model step is successful
   run_model_task = current_dag.get_task(task_id=f"{benchmark_id}.run_model")
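
A note on the SSH-key handling introduced in xlml/utils/gpu.py above: both new tasks rely on the GCE convention that the ssh-keys metadata value is a newline-separated list of username:public_key entries. get_existing_resource appends a cloud-ml-auto-solutions entry for the one-time key, and clean_up_ssh_keys later strips exactly that newline-prefixed entry with str.replace. A self-contained sketch of just that string handling, illustrative only and using made-up key values; the real tasks operate on compute_v1 instance metadata as shown in the diff:

# Illustrative sketch of the ssh-keys metadata string handling used by
# get_existing_resource and clean_up_ssh_keys; no GCE API calls are made here.
USERNAME = "cloud-ml-auto-solutions"


def append_one_time_key(ssh_keys_value: str, public_key: str) -> str:
  # Mirrors get_existing_resource: append on a new line when a value already
  # exists, or create the entry from scratch when no ssh-keys item is present.
  entry = f"{USERNAME}:{public_key}"
  return f"{ssh_keys_value}\n{entry}" if ssh_keys_value else entry


def remove_one_time_key(ssh_keys_value: str, public_key: str) -> str:
  # Mirrors clean_up_ssh_keys: remove the newline-prefixed entry that was
  # appended above, leaving any pre-existing keys untouched.
  return ssh_keys_value.replace(f"\n{USERNAME}:{public_key}", "")


if __name__ == "__main__":
  existing = "someone-else:ssh-ed25519 AAAA... someone@host"  # hypothetical
  one_time_key = "ssh-ed25519 BBBB... airflow-generated"  # hypothetical
  updated = append_one_time_key(existing, one_time_key)
  assert remove_one_time_key(updated, one_time_key) == existing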