Reuse Existing GPU Instance for Automation #538

Merged
merged 16 commits on Jan 16, 2025
Changes from 6 commits
2 changes: 2 additions & 0 deletions dags/inference/configs/trt_llm_inference_config.py
@@ -36,6 +36,7 @@ def get_trt_llm_gpu_config(
project: Project,
network: str,
subnetwork: str,
existed_instance_name: str = None,
) -> task.GpuCreateResourceTask:
set_up_cmds = (
"pip install --upgrade pip",
@@ -160,4 +161,5 @@ def get_trt_llm_gpu_config(
job_test_config,
job_gcp_config,
job_metric_config,
existed_instance_name=existed_instance_name,
)
3 changes: 3 additions & 0 deletions dags/inference/configs/trt_llm_mlperf_v40_config.py
@@ -37,6 +37,7 @@ def get_trt_llm_mlperf_v40_gpu_config(
project: Project,
network: str,
subnetwork: str,
existed_instance_name: str = None,
model_configs: Dict = {},
) -> task.GpuCreateResourceTask:
docker_container_name = "mlperf-inference"
@@ -109,6 +110,7 @@ def get_trt_llm_mlperf_v40_gpu_config(
docker_cmd = " && ".join(docker_cmds)
run_model_cmds = (
"pip install jsonlines",
f"docker restart {docker_container_name}",
f'docker exec -i {docker_container_name} /bin/bash -c "{docker_cmd}"',
make_jsonl_converter_cmd,
"cat jsonl_converter.py",
@@ -152,4 +154,5 @@ def get_trt_llm_mlperf_v40_gpu_config(
job_test_config,
job_gcp_config,
job_metric_config,
existed_instance_name=existed_instance_name,
)
2 changes: 2 additions & 0 deletions dags/inference/configs/trt_llm_mlperf_v41_config.py
@@ -37,6 +37,7 @@ def get_trt_llm_mlperf_gpu_config(
project: Project,
network: str,
subnetwork: str,
existed_instance_name: str = None,
benchmark_configs: Dict = {},
model_parameters: Dict = {},
parameter_positions: Dict = {},
@@ -215,4 +216,5 @@ def get_trt_llm_mlperf_gpu_config(
job_test_config,
job_gcp_config,
job_metric_config,
existed_instance_name=existed_instance_name,
)
1 change: 1 addition & 0 deletions dags/inference/trt_llm_mlperf_v40_inference.py
@@ -61,4 +61,5 @@
network=INFERENCE_NETWORKS,
subnetwork=H100_INFERENCE_SUBNETWORKS,
model_configs=model_configs,
existed_instance_name="yijiaj-test-h100x8",
).run()
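
The three config changes above follow one pattern: each builder accepts an optional existed_instance_name and forwards it to GpuCreateResourceTask, and the DAG above opts in by passing a concrete instance name. Below is a minimal, hedged sketch of that pattern with hypothetical stand-ins (ExampleResourceTask, build_config); it is illustrative only and not code from this repository.

import dataclasses
from typing import Optional


@dataclasses.dataclass
class ExampleResourceTask:
  """Hypothetical stand-in for GpuCreateResourceTask."""

  benchmark_id: str
  existed_instance_name: Optional[str] = None

  def run(self) -> str:
    # Reuse path when an instance name is supplied; create path otherwise.
    if self.existed_instance_name is not None:
      return f"reuse {self.existed_instance_name} for {self.benchmark_id}"
    return f"create a new GPU VM for {self.benchmark_id}"


def build_config(
    benchmark_id: str,
    existed_instance_name: Optional[str] = None,
) -> ExampleResourceTask:
  # The config layer simply threads the optional name through to the task.
  return ExampleResourceTask(
      benchmark_id=benchmark_id,
      existed_instance_name=existed_instance_name,
  )


print(build_config("trt-llm-demo").run())
print(build_config("trt-llm-demo", existed_instance_name="prebuilt-h100").run())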
79 changes: 77 additions & 2 deletions xlml/apis/task.py
@@ -349,6 +349,7 @@ class GpuCreateResourceTask(BaseTask):
task_metric_config: metric configuration (e.g., result gcs path).
gpu_create_timeout: timeout when waiting for the GPU vm creation.
install_nvidia_drivers: whether to install Nvidia drivers.
existed_instance_name: name of an existing GPU instance to reuse; when set, the task reuses it instead of creating a new VM.
"""

image_project: str
Expand All @@ -358,16 +359,21 @@ class GpuCreateResourceTask(BaseTask):
task_metric_config: Optional[metric_config.MetricConfig] = None
gpu_create_timeout: datetime.timedelta = datetime.timedelta(minutes=60)
install_nvidia_drivers: bool = False
existed_instance_name: str = None

def run(self) -> DAGNode:
"""Run a test job.

Returns:
A task group with the following tasks chained: provision, run_model,
post_process, clean_up.
post_process, clean_up if no existing instance is reused, or a task
group with run_model and post_process only.
"""
# piz: We skip the queued resource for GPU for now since there is no queued
# resource command for GPU.
if self.existed_instance_name is not None:
return self.run_with_existed_instance()

with TaskGroup(
group_id=self.task_test_config.benchmark_id, prefix_group_id=True
) as group:
@@ -399,6 +405,58 @@ def run(self) -> DAGNode:
provision >> run_model >> post_process >> clean_up
return group

def run_with_existed_instance(self) -> DAGNode:
"""Run a test job.

Returns:
A task group with the following tasks chained: run_model and post_process.
jyj0w0 marked this conversation as resolved.
Show resolved Hide resolved
"""
with TaskGroup(
group_id=self.task_test_config.benchmark_id, prefix_group_id=True
) as group:
(
provision,
ip_address,
ssh_keys,
gcs_location,
) = self.provision_via_existed_instance()
if (
self.task_metric_config
and self.task_metric_config.use_runtime_generated_gcs_folder
):
env_variable = {
f"{metric_config.SshEnvVars.GCS_OUTPUT.name}": gcs_location
}
else:
env_variable = None
post_process = self.post_process(gcs_location, use_existed_instance=True)
run_model = self.run_model(ip_address, ssh_keys, env_variable)
clean_up = self.clean_up_existed_instance(ssh_keys)
provision >> run_model >> post_process >> clean_up
return group

def provision_via_existed_instance(
self,
) -> Tuple[DAGNode, airflow.XComArg, airflow.XComArg, airflow.XComArg,]:
"""Provision an existed GPU accelerator.

Returns:
A DAG node that prepares the existing GPU instance, an XCom value for the
host's IP address, an XCom value for the SSH keys, and an XCom value for
the GCS folder location.
"""
with TaskGroup(group_id="provision") as group:
ssh_keys = ssh.generate_ssh_keys()
ip_address = gpu.get_existed_resource(
instance_name=self.existed_instance_name,
ssh_keys=ssh_keys,
gcp=self.task_gcp_config,
)
gcs_location = name_format.generate_gcs_folder_location(
self.task_test_config.gcs_subfolder,
self.task_test_config.benchmark_id,
)
return group, ip_address, ssh_keys, gcs_location

def provision(
self,
) -> Tuple[
@@ -476,7 +534,9 @@ def run_model(
)

def post_process(
self, result_location: Optional[airflow.XComArg] = None
self,
result_location: Optional[airflow.XComArg] = None,
use_existed_instance: bool = False,
) -> DAGNode:
"""Process metrics and metadata, and insert them into BigQuery tables.

@@ -491,6 +551,7 @@ def post_process(
self.task_metric_config,
self.task_gcp_config,
folder_location=result_location,
use_existed_instance=use_existed_instance,
)
return group

@@ -513,6 +574,20 @@ def clean_up(
resource, project_id, zone
)

def clean_up_existed_instance(self, ssh_keys: airflow.XComArg) -> DAGNode:
"""Clean up existed GPU resources - remove new generated ssh_keys.

Args:
ssh_keys: the generated SSH keys to remove from the instance.

Returns:
A DAG node that cleans up the ssh_keys.
"""
return gpu.clean_up_ssh_keys(
instance_name=self.existed_instance_name,
ssh_keys=ssh_keys,
gcp=self.task_gcp_config,
)


# TODO(ranran): This class is big. Let's move it to a new file.
@dataclasses.dataclass
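
For readers less familiar with the wiring used in run_with_existed_instance, here is a hedged, standalone sketch of the same TaskGroup chaining (provision >> run_model >> post_process >> clean_up) with trivial placeholder tasks; the DAG id, task bodies, and the Airflow 2.4+ schedule argument are assumptions, not repository code.

import datetime

from airflow import models
from airflow.decorators import task
from airflow.utils.task_group import TaskGroup


@task
def provision() -> str:
  # Placeholder: the real provision group returns the instance IP and SSH keys via XCom.
  return "10.0.0.2"


@task
def run_model(ip_address: str) -> None:
  print(f"run benchmark against {ip_address}")


@task
def post_process() -> None:
  print("upload metrics to BigQuery")


@task(trigger_rule="all_done")
def clean_up() -> None:
  # Mirrors clean_up_existed_instance: runs even if an upstream task fails.
  print("remove the temporary ssh keys")


with models.DAG(
    dag_id="reuse_gpu_sketch",  # hypothetical DAG id
    start_date=datetime.datetime(2025, 1, 1),
    schedule=None,  # assumes Airflow 2.4+; older versions use schedule_interval
) as dag:
  with TaskGroup(group_id="benchmark") as group:
    ip_address = provision()
    run = run_model(ip_address)  # implicit provision -> run_model dependency
    post = post_process()
    clean = clean_up()
    run >> post >> clean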
122 changes: 122 additions & 0 deletions xlml/utils/gpu.py
@@ -129,6 +129,128 @@ def generate_gpu_name() -> str:
return f"gpu-{str(uuid.uuid4())}"


@task
def get_existed_resource(
instance_name: str,
ssh_keys: ssh.SshKeys,
gcp: gcp_config.GCPConfig,
) -> airflow.XComArg:
"""Reach a resource node that is already created.

Args:
instance_name: name of the existing instance.
ssh_keys: SSH keys to add to the instance metadata.
gcp: GCP project/zone configuration.

Returns:
The ip address of the GPU VM.
"""
instance_client = compute_v1.InstancesClient()
instance_request = compute_v1.GetInstanceRequest(
instance=instance_name,
project=gcp.project_name,
zone=gcp.zone,
)
instance = instance_client.get(request=instance_request)
logging.info(
f"Resource retrieve status: {instance.status}, {instance.status_message}"
)

ip_address = instance.network_interfaces[0].network_i_p
metadata = instance.metadata
items = metadata.items or []
ssh_key_exist = False
for item in items:
if item.key == "ssh-keys":
ssh_key_exist = True
item.value = (
item.value + "\n" + f"cloud-ml-auto-solutions:{ssh_keys.public}"
)
break
if not ssh_key_exist:
items.append({
"key": "ssh-keys",
"value": f"cloud-ml-auto-solutions:{ssh_keys.public}",
})
metadata.items = items
metadata_request = compute_v1.SetMetadataInstanceRequest(
instance=instance_name,
project=gcp.project_name,
zone=gcp.zone,
metadata_resource=metadata,
)
operation = instance_client.set_metadata(request=metadata_request)
if operation.error:
logging.error(
(
"Error during instance set metadata: [Code:"
f" {operation.http_error_status_code}]:"
f" {operation.http_error_message}"
f" {operation.error}"
),
)
raise operation.exception() or RuntimeError(operation.http_error_message)
elif operation.warnings:
logging.warning("Warnings during instance set metadata:\n")
for warning in operation.warnings:
logging.warning(f" - {warning.code}: {warning.message}")

return ip_address


@task(trigger_rule="all_done")
def clean_up_ssh_keys(
instance_name: str,
ssh_keys: ssh.SshKeys,
gcp: gcp_config.GCPConfig,
) -> None:
"""Remove the generated ssh_keys from existed instance.

Args:
instance_name: name of the existing instance.
ssh_keys: SSH keys to remove from the instance metadata.
gcp: GCP project/zone configuration.
"""
instance_client = compute_v1.InstancesClient()
instance_request = compute_v1.GetInstanceRequest(
instance=instance_name,
project=gcp.project_name,
zone=gcp.zone,
)
instance = instance_client.get(request=instance_request)
logging.info(
f"Resource get status: {instance.status}, {instance.status_message}"
)
metadata = instance.metadata
for item in metadata.items:
if item.key == "ssh-keys":
item.value = item.value.replace(
f"\ncloud-ml-auto-solutions:{ssh_keys.public}", ""
)
break
metadata_request = compute_v1.SetMetadataInstanceRequest(
instance=instance_name,
project=gcp.project_name,
zone=gcp.zone,
metadata_resource=metadata,
)
operation = instance_client.set_metadata(request=metadata_request)
if operation.error:
logging.error(
(
"Error during instance set metadata: [Code:"
f" {operation.http_error_status_code}]:"
f" {operation.http_error_message}"
f" {operation.error}"
),
)
raise operation.exception() or RuntimeError(operation.http_error_message)
elif operation.warnings:
logging.warning("Warnings during instance set metadata:\n")
for warning in operation.warnings:
logging.warning(f" - {warning.code}: {warning.message}")


@task_group
def create_resource(
gpu_name: airflow.XComArg,
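
Both gpu.py helpers above edit the instance's ssh-keys metadata entry, which GCE stores as newline-separated username:public_key lines. Below is a hedged, pure-string sketch of the append/remove logic those helpers apply, independent of the Compute API; the helper names and sample keys are hypothetical.

USERNAME = "cloud-ml-auto-solutions"  # username used by the helpers in this PR


def append_ssh_key(existing_value: str, public_key: str) -> str:
  """Append a '<user>:<key>' line, as get_existed_resource does."""
  entry = f"{USERNAME}:{public_key}"
  return f"{existing_value}\n{entry}" if existing_value else entry


def remove_ssh_key(value: str, public_key: str) -> str:
  """Drop the previously appended line, as clean_up_ssh_keys does."""
  return value.replace(f"\n{USERNAME}:{public_key}", "")


keys = append_ssh_key("alice:ssh-ed25519 AAAA... alice", "ssh-ed25519 BBBB... agent")
print(keys)
print(remove_ssh_key(keys, "ssh-ed25519 BBBB... agent"))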
32 changes: 22 additions & 10 deletions xlml/utils/metric.py
@@ -591,6 +591,7 @@ def get_gke_job_status(
def get_gce_job_status(
task_test_config: test_config.TestConfig[test_config.Accelerator],
use_startup_script: bool,
use_existed_instance: bool,
) -> bigquery.JobStatus:
"""Get job status for the GCE run.

@@ -614,9 +615,14 @@
task_id=f"{benchmark_id}.provision.create_queued_resource.wait_for_ready_queued_resource"
)
elif isinstance(task_test_config, test_config.GpuVmTest):
wait_task = current_dag.get_task(
task_id=f"{benchmark_id}.provision.create_resource.get_ip_address"
)
if use_existed_instance:
wait_task = current_dag.get_task(
task_id=f"{benchmark_id}.provision.get_existed_resource"
)
else:
wait_task = current_dag.get_task(
task_id=f"{benchmark_id}.provision.create_resource.get_ip_address"
)
else:
raise NotImplementedError(
f"Unable to get task for {type(task_test_config.accelerator)}."
Expand All @@ -631,12 +637,15 @@ def get_gce_job_status(
return bigquery.JobStatus.MISSED

# check setup status to see if setup step is successful
setup_task = current_dag.get_task(task_id=f"{benchmark_id}.provision.setup")
setup_ti = TaskInstance(setup_task, execution_date)
setup_state = setup_ti.current_state()
if setup_state == TaskState.FAILED.value:
logging.info("The setup state is failed, and the job status is failed.")
return bigquery.JobStatus.FAILED
if not use_existed_instance:
setup_task = current_dag.get_task(
task_id=f"{benchmark_id}.provision.setup"
)
setup_ti = TaskInstance(setup_task, execution_date)
setup_state = setup_ti.current_state()
if setup_state == TaskState.FAILED.value:
logging.info("The setup state is failed, and the job status is failed.")
return bigquery.JobStatus.FAILED

# check run_model status to see if run_model step is successful
run_model_task = current_dag.get_task(task_id=f"{benchmark_id}.run_model")
Expand Down Expand Up @@ -692,6 +701,7 @@ def process_metrics(
task_metric_config: Optional[metric_config.MetricConfig],
task_gcp_config: gcp_config.GCPConfig,
use_startup_script: bool = False,
use_existed_instance: bool = False,
folder_location: Optional[str] = None,
) -> None:
benchmark_id = task_test_config.benchmark_id
Expand Down Expand Up @@ -771,7 +781,9 @@ def process_metrics(
elif isinstance(task_test_config, test_config.GpuGkeTest):
test_job_status = get_gke_job_status(task_test_config)
else:
test_job_status = get_gce_job_status(task_test_config, use_startup_script)
test_job_status = get_gce_job_status(
task_test_config, use_startup_script, use_existed_instance
)

for index in range(len(metadata_history_rows_list)):
job_history_row = bigquery.JobHistoryRow(
Expand Down