Reuse Existing GPU Instance for Automation (#538)
* use existing instance

* Use existing instance and test on A3

* format

* nit

* nit

* format

* nit and wrap
jyj0w0 authored Jan 16, 2025
1 parent cbb3a97 commit 5d6ff02
Showing 10 changed files with 241 additions and 10 deletions.
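
The three TensorRT-LLM config files below repeat the same opt-in pattern: a new optional existing_instance_name parameter is threaded into the test config (as the boolean use_existing_instance) and into the resource task (as the instance to reuse). A minimal sketch of the dispatch this enables (illustrative lifecycle strings, not repo code):

from typing import Optional


def build_lifecycle(existing_instance_name: Optional[str] = None) -> str:
    # The string parameter doubles as the feature switch.
    use_existing_instance = existing_instance_name is not None
    if use_existing_instance:
        # Reuse path: the VM outlives the run; only the SSH keys are removed.
        return (
            f"reuse {existing_instance_name}: provision -> run_model"
            " -> post_process -> clean_up_ssh_keys"
        )
    # Default path: the create/delete lifecycle is unchanged.
    return "create VM: provision -> run_model -> post_process -> delete VM"


print(build_lifecycle())                      # default create/delete behavior
print(build_lifecycle("yijiaj-test-h100x8"))  # reuse, as wired up for the v40 DAG below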
3 changes: 3 additions & 0 deletions dags/inference/configs/trt_llm_inference_config.py
@@ -36,6 +36,7 @@ def get_trt_llm_gpu_config(
project: Project,
network: str,
subnetwork: str,
existing_instance_name: Optional[str] = None,
) -> task.GpuCreateResourceTask:
set_up_cmds = (
"pip install --upgrade pip",
@@ -141,6 +142,7 @@ def get_trt_llm_gpu_config(
timeout=datetime.timedelta(minutes=time_out_in_min),
task_owner=test_owner.YIJIA_J,
gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/trt_llm",
use_existing_instance=existing_instance_name is not None,
)

job_gcp_config = gcp_config.GCPConfig(
@@ -160,4 +162,5 @@ def get_trt_llm_gpu_config(
job_test_config,
job_gcp_config,
job_metric_config,
existing_instance_name=existing_instance_name,
)
4 changes: 4 additions & 0 deletions dags/inference/configs/trt_llm_mlperf_v40_config.py
@@ -37,6 +37,7 @@ def get_trt_llm_mlperf_v40_gpu_config(
project: Project,
network: str,
subnetwork: str,
existing_instance_name: Optional[str] = None,
model_configs: Dict = {},
) -> task.GpuCreateResourceTask:
docker_container_name = "mlperf-inference"
@@ -109,6 +110,7 @@ def get_trt_llm_mlperf_v40_gpu_config(
docker_cmd = " && ".join(docker_cmds)
run_model_cmds = (
"pip install jsonlines",
f"docker restart {docker_container_name}",
f'docker exec -i {docker_container_name} /bin/bash -c "{docker_cmd}"',
make_jsonl_converter_cmd,
"cat jsonl_converter.py",
@@ -133,6 +135,7 @@ def get_trt_llm_mlperf_v40_gpu_config(
timeout=datetime.timedelta(minutes=time_out_in_min),
task_owner=test_owner.YIJIA_J,
gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/trt_llm_mlperf_v40",
use_existing_instance=existing_instance_name is not None,
)

job_gcp_config = gcp_config.GCPConfig(
@@ -152,4 +155,5 @@ def get_trt_llm_mlperf_v40_gpu_config(
job_test_config,
job_gcp_config,
job_metric_config,
existing_instance_name=existing_instance_name,
)
3 changes: 3 additions & 0 deletions dags/inference/configs/trt_llm_mlperf_v41_config.py
@@ -37,6 +37,7 @@ def get_trt_llm_mlperf_gpu_config(
project: Project,
network: str,
subnetwork: str,
existing_instance_name: Optional[str] = None,
benchmark_configs: Dict = {},
model_parameters: Dict = {},
parameter_positions: Dict = {},
@@ -196,6 +197,7 @@ def get_trt_llm_mlperf_gpu_config(
timeout=datetime.timedelta(minutes=time_out_in_min),
task_owner=test_owner.YIJIA_J,
gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/trt_llm_mlperf_v41",
use_existing_instance=existing_instance_name is not None,
)

job_gcp_config = gcp_config.GCPConfig(
@@ -215,4 +217,5 @@ def get_trt_llm_mlperf_gpu_config(
job_test_config,
job_gcp_config,
job_metric_config,
existing_instance_name=existing_instance_name,
)
1 change: 1 addition & 0 deletions dags/inference/trt_llm_mlperf_v40_inference.py
@@ -61,4 +61,5 @@
network=INFERENCE_NETWORKS,
subnetwork=H100_INFERENCE_SUBNETWORKS,
model_configs=model_configs,
existing_instance_name="yijiaj-test-h100x8",
).run()
1 change: 1 addition & 0 deletions dags/pytorch_xla/configs/pytorchxla_torchbench_config.py
@@ -475,6 +475,7 @@ def get_torchbench_gpu_config(
timeout=datetime.timedelta(minutes=time_out_in_min),
task_owner=test_owner.PEI_Z,
gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/torchbench",
use_existing_instance=False,
)

job_metric_config = metric_config.MetricConfig(
1 change: 1 addition & 0 deletions dags/solutions_team/configs/vllm/vllm_benchmark_config.py
@@ -186,6 +186,7 @@ def get_gpu_vllm_gce_config(
timeout=datetime.timedelta(minutes=time_out_in_min),
task_owner=test_owner.RICHARD_L,
gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/vllm_benchmark",
use_existing_instance=False,
)

job_gcp_config = gcp_config.GCPConfig(
74 changes: 73 additions & 1 deletion xlml/apis/task.py
@@ -349,6 +349,7 @@ class GpuCreateResourceTask(BaseTask):
task_metric_config: metric configuration (e.g., result gcs path).
gpu_create_timeout: timeout when waiting for the GPU vm creation.
install_nvidia_drivers: whether to install Nvidia drivers.
existing_instance_name: name of an existing GPU instance to reuse; when set, the task uses that instance instead of creating a new one.
"""

image_project: str
@@ -358,6 +359,7 @@ class GpuCreateResourceTask(BaseTask):
task_metric_config: Optional[metric_config.MetricConfig] = None
gpu_create_timeout: datetime.timedelta = datetime.timedelta(minutes=60)
install_nvidia_drivers: bool = False
existing_instance_name: Optional[str] = None

def run(self) -> DAGNode:
"""Run a test job.
@@ -368,6 +370,9 @@ def run(self) -> DAGNode:
"""
# piz: We skip the queued resource for GPU for now since there is no queued
# resource command for GPU.
if self.existing_instance_name is not None:
return self.run_with_existing_instance()

with TaskGroup(
group_id=self.task_test_config.benchmark_id, prefix_group_id=True
) as group:
@@ -399,6 +404,58 @@ def run(self) -> DAGNode:
provision >> run_model >> post_process >> clean_up
return group

def run_with_existing_instance(self) -> DAGNode:
"""Run a test job via existing instance.
Returns:
A task group with the following tasks chained: provision, run_model and post_process, clean_up.
"""
with TaskGroup(
group_id=self.task_test_config.benchmark_id, prefix_group_id=True
) as group:
(
provision,
ip_address,
ssh_keys,
gcs_location,
) = self.provision_via_existing_instance()
if (
self.task_metric_config
and self.task_metric_config.use_runtime_generated_gcs_folder
):
env_variable = {
f"{metric_config.SshEnvVars.GCS_OUTPUT.name}": gcs_location
}
else:
env_variable = None
post_process = self.post_process(gcs_location)
run_model = self.run_model(ip_address, ssh_keys, env_variable)
clean_up = self.clean_up_existing_instance(ssh_keys)
provision >> run_model >> post_process >> clean_up
return group

def provision_via_existing_instance(
self,
) -> Tuple[DAGNode, airflow.XComArg, airflow.XComArg, airflow.XComArg]:
"""Provision an existing GPU accelerator.
Returns:
A DAG node that will provision a GPU, an XCome value of the ip address
for the host,an XCom value for the SSH keys.
"""
with TaskGroup(group_id="provision") as group:
ssh_keys = ssh.generate_ssh_keys()
ip_address = gpu.get_existing_resource(
instance_name=self.existing_instance_name,
ssh_keys=ssh_keys,
gcp=self.task_gcp_config,
)
gcs_location = name_format.generate_gcs_folder_location(
self.task_test_config.gcs_subfolder,
self.task_test_config.benchmark_id,
)
return group, ip_address, ssh_keys, gcs_location

def provision(
self,
) -> Tuple[
@@ -476,7 +533,8 @@ def run_model(
)

def post_process(
self, result_location: Optional[airflow.XComArg] = None
self,
result_location: Optional[airflow.XComArg] = None,
) -> DAGNode:
"""Process metrics and metadata, and insert them into BigQuery tables.
@@ -513,6 +571,20 @@ def clean_up(
resource, project_id, zone
)

def clean_up_existing_instance(self, ssh_keys: airflow.XComArg) -> DAGNode:
"""Clean up existing GPU resources - remove the one-time use generated ssh_keys.
Args:
ssh_keys: generated GPU's one-time use SSH keys to be removed.
Returns:
A DAG node that cleaned up the ssh_keys.
"""
return gpu.clean_up_ssh_keys(
instance_name=self.existing_instance_name,
ssh_keys=ssh_keys,
gcp=self.task_gcp_config,
)


# TODO(ranran): This class is big. Let's move it to a new file.
@dataclasses.dataclass
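The reuse path keeps the same four-stage chain as the original run(); only the endpoints change. A comment-only summary of the XCom data flow wired up above (a reading aid, not additional repo code):

# Data flow in run_with_existing_instance, per the diff above:
#
#   provision_via_existing_instance()
#     ssh_keys     = ssh.generate_ssh_keys()           # fresh one-time key pair
#     ip_address   = gpu.get_existing_resource(...)    # look up the VM, install the public key
#     gcs_location = name_format.generate_gcs_folder_location(...)
#
#   run_model(ip_address, ssh_keys, env_variable)      # SSH into the reused VM
#   post_process(gcs_location)                         # metrics/metadata into BigQuery
#   clean_up_existing_instance(ssh_keys)               # strip the key; the VM is kept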
2 changes: 2 additions & 0 deletions xlml/apis/test_config.py
@@ -224,11 +224,13 @@ class GpuVmTest(TestConfig[Gpu]):
test_name: Unique name for this test/model.
set_up_cmds: List of commands to run once when GPU is created.
run_model_cmds: List of commands to run the model under test.
use_existing_instance: Whether to use an existing GPU instance.
"""

test_name: str
set_up_cmds: Iterable[str]
run_model_cmds: Iterable[str]
use_existing_instance: bool

@property
def benchmark_id(self) -> str:
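Because use_existing_instance is declared without a default, every GpuVmTest construction site must now pass it explicitly; that is why the torchbench and vLLM configs above gained use_existing_instance=False. A minimal sketch of that contract, using a hypothetical stand-in dataclass rather than the repo's GpuVmTest (which carries more fields):

import dataclasses
from typing import Iterable


@dataclasses.dataclass
class GpuVmTestSketch:
    test_name: str
    set_up_cmds: Iterable[str]
    run_model_cmds: Iterable[str]
    use_existing_instance: bool  # no default: each call site must opt in or out


GpuVmTestSketch("demo", [], [], use_existing_instance=False)  # OK
# GpuVmTestSketch("demo", [], [])  # TypeError: missing required argument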
122 changes: 122 additions & 0 deletions xlml/utils/gpu.py
@@ -129,6 +129,128 @@ def generate_gpu_name() -> str:
return f"gpu-{str(uuid.uuid4())}"


@task
def get_existing_resource(
instance_name: str,
ssh_keys: ssh.SshKeys,
gcp: gcp_config.GCPConfig,
) -> airflow.XComArg:
"""Reach a resource node that is already created.
Args:
instance_name: name of the existing instance.
ssh_keys: airflow.XComArg,
gcp: GCP project/zone configuration.
Returns:
The ip address of the GPU VM.
"""
instance_client = compute_v1.InstancesClient()
instance_request = compute_v1.GetInstanceRequest(
instance=instance_name,
project=gcp.project_name,
zone=gcp.zone,
)
instance = instance_client.get(request=instance_request)
logging.info(
f"Resource retrieve status: {instance.status}, {instance.status_message}"
)

ip_address = instance.network_interfaces[0].network_i_p
metadata = instance.metadata
items = metadata.items or []
ssh_key_exist = False
for item in items:
if item.key == "ssh-keys":
ssh_key_exist = True
item.value = (
item.value + "\n" + f"cloud-ml-auto-solutions:{ssh_keys.public}"
)
break
if not ssh_key_exist:
items.append({
"key": "ssh-keys",
"value": f"cloud-ml-auto-solutions:{ssh_keys.public}",
})
metadata.items = items
metadata_request = compute_v1.SetMetadataInstanceRequest(
instance=instance_name,
project=gcp.project_name,
zone=gcp.zone,
metadata_resource=metadata,
)
operation = instance_client.set_metadata(request=metadata_request)
if operation.error:
logging.error(
(
"Error during instance set metadata: [Code:"
f" {operation.http_error_status_code}]:"
f" {operation.http_error_message}"
f" {operation.error}"
),
)
raise operation.exception() or RuntimeError(operation.http_error_message)
elif operation.warnings:
logging.warning("Warnings during instance set metadata:\n")
for warning in operation.warnings:
logging.warning(f" - {warning.code}: {warning.message}")

return ip_address


@task(trigger_rule="all_done")
def clean_up_ssh_keys(
instance_name: str,
ssh_keys: ssh.SshKeys,
gcp: gcp_config.GCPConfig,
) -> None:
"""Remove the generated one-time use ssh_keys from existing instance.
Args:
instance_name: name of the existing instance.
ssh_keys: airflow.XComArg,
gcp: GCP project/zone configuration.
"""
instance_client = compute_v1.InstancesClient()
instance_request = compute_v1.GetInstanceRequest(
instance=instance_name,
project=gcp.project_name,
zone=gcp.zone,
)
instance = instance_client.get(request=instance_request)
logging.info(
f"Resource get status: {instance.status}, {instance.status_message}"
)
metadata = instance.metadata
for item in metadata.items:
if item.key == "ssh-keys":
item.value = item.value.replace(
f"\ncloud-ml-auto-solutions:{ssh_keys.public}", ""
)
break
metadata_request = compute_v1.SetMetadataInstanceRequest(
instance=instance_name,
project=gcp.project_name,
zone=gcp.zone,
metadata_resource=metadata,
)
operation = instance_client.set_metadata(request=metadata_request)
if operation.error:
logging.error(
(
"Error during instance set metadata: [Code:"
f" {operation.http_error_status_code}]:"
f" {operation.http_error_message}"
f" {operation.error}"
),
)
raise operation.exception() or RuntimeError(operation.http_error_message)
elif operation.warnings:
logging.warning("Warnings during instance set metadata:\n")
for warning in operation.warnings:
logging.warning(f" - {warning.code}: {warning.message}")


@task_group
def create_resource(
gpu_name: airflow.XComArg,
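Both new tasks edit the instance's "ssh-keys" metadata entry, whose value is a newline-separated list of USERNAME:PUBLIC_KEY lines (the standard GCE metadata SSH-key format). A string-level sketch of the round trip, with the compute_v1 plumbing stripped out and illustrative key material:

USER = "cloud-ml-auto-solutions"


def add_key(ssh_keys_value: str, public_key: str) -> str:
    # Mirrors get_existing_resource: append a one-time entry.
    return ssh_keys_value + "\n" + f"{USER}:{public_key}"


def remove_key(ssh_keys_value: str, public_key: str) -> str:
    # Mirrors clean_up_ssh_keys: strip exactly the entry that was appended.
    return ssh_keys_value.replace(f"\n{USER}:{public_key}", "")


existing = "alice:ssh-ed25519 AAAAC3Nz... alice@host"   # illustrative
one_time = "ssh-ed25519 AAAAB3Nz... airflow-generated"  # illustrative
patched = add_key(existing, one_time)
assert remove_key(patched, one_time) == existing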