Reuse Existing GPU Instance for Automation #538

Merged
merged 16 commits on Jan 16, 2025

Changes from all commits
3 changes: 3 additions & 0 deletions dags/inference/configs/trt_llm_inference_config.py
@@ -36,6 +36,7 @@ def get_trt_llm_gpu_config(
project: Project,
network: str,
subnetwork: str,
existing_instance_name: str = None,
) -> task.GpuCreateResourceTask:
set_up_cmds = (
"pip install --upgrade pip",
@@ -141,6 +142,7 @@ def get_trt_llm_gpu_config(
timeout=datetime.timedelta(minutes=time_out_in_min),
task_owner=test_owner.YIJIA_J,
gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/trt_llm",
use_existing_instance=existing_instance_name is not None,
)

job_gcp_config = gcp_config.GCPConfig(
@@ -160,4 +162,5 @@ def get_trt_llm_gpu_config(
job_test_config,
job_gcp_config,
job_metric_config,
existing_instance_name=existing_instance_name,
)
4 changes: 4 additions & 0 deletions dags/inference/configs/trt_llm_mlperf_v40_config.py
@@ -37,6 +37,7 @@ def get_trt_llm_mlperf_v40_gpu_config(
project: Project,
network: str,
subnetwork: str,
existing_instance_name: str = None,
model_configs: Dict = {},
) -> task.GpuCreateResourceTask:
docker_container_name = "mlperf-inference"
@@ -109,6 +110,7 @@ def get_trt_llm_mlperf_v40_gpu_config(
docker_cmd = " && ".join(docker_cmds)
run_model_cmds = (
"pip install jsonlines",
f"docker restart {docker_container_name}",
f'docker exec -i {docker_container_name} /bin/bash -c "{docker_cmd}"',
make_jsonl_converter_cmd,
"cat jsonl_converter.py",
@@ -133,6 +135,7 @@ def get_trt_llm_mlperf_v40_gpu_config(
timeout=datetime.timedelta(minutes=time_out_in_min),
task_owner=test_owner.YIJIA_J,
gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/trt_llm_mlperf_v40",
use_existing_instance=existing_instance_name is not None,
)

job_gcp_config = gcp_config.GCPConfig(
@@ -152,4 +155,5 @@ def get_trt_llm_mlperf_v40_gpu_config(
job_test_config,
job_gcp_config,
job_metric_config,
existing_instance_name=existing_instance_name,
)
3 changes: 3 additions & 0 deletions dags/inference/configs/trt_llm_mlperf_v41_config.py
@@ -37,6 +37,7 @@ def get_trt_llm_mlperf_gpu_config(
project: Project,
network: str,
subnetwork: str,
existing_instance_name: str = None,
benchmark_configs: Dict = {},
model_parameters: Dict = {},
parameter_positions: Dict = {},
@@ -196,6 +197,7 @@ def get_trt_llm_mlperf_gpu_config(
timeout=datetime.timedelta(minutes=time_out_in_min),
task_owner=test_owner.YIJIA_J,
gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/trt_llm_mlperf_v41",
use_existing_instance=existing_instance_name is not None,
)

job_gcp_config = gcp_config.GCPConfig(
@@ -215,4 +217,5 @@ def get_trt_llm_mlperf_gpu_config(
job_test_config,
job_gcp_config,
job_metric_config,
existing_instance_name=existing_instance_name,
)
1 change: 1 addition & 0 deletions dags/inference/trt_llm_mlperf_v40_inference.py
@@ -61,4 +61,5 @@
network=INFERENCE_NETWORKS,
subnetwork=H100_INFERENCE_SUBNETWORKS,
model_configs=model_configs,
existing_instance_name="yijiaj-test-h100x8",
).run()
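For context, a minimal sketch of the two call-site variants this change enables (not part of the diff; the module alias and the arguments hidden by `...` are assumptions): the DAG above opts into reuse by naming an instance, while a DAG that should keep provisioning a fresh VM simply omits the new keyword, since the `None` default leaves `use_existing_instance` as `False`.

```python
# Sketch only: the module alias and the elided arguments (shown as ...) are assumptions.

# Reuse a long-lived GPU VM, as the DAG change above does:
trt_llm_mlperf_v40_config.get_trt_llm_mlperf_v40_gpu_config(
    ...,  # machine/image/project arguments unchanged from the existing DAG
    network=INFERENCE_NETWORKS,
    subnetwork=H100_INFERENCE_SUBNETWORKS,
    model_configs=model_configs,
    existing_instance_name="yijiaj-test-h100x8",
).run()

# Previous behavior (create a fresh VM and delete it afterwards): omit the keyword.
trt_llm_mlperf_v40_config.get_trt_llm_mlperf_v40_gpu_config(
    ...,
    network=INFERENCE_NETWORKS,
    subnetwork=H100_INFERENCE_SUBNETWORKS,
    model_configs=model_configs,
).run()
```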
1 change: 1 addition & 0 deletions dags/pytorch_xla/configs/pytorchxla_torchbench_config.py
@@ -475,6 +475,7 @@ def get_torchbench_gpu_config(
timeout=datetime.timedelta(minutes=time_out_in_min),
task_owner=test_owner.PEI_Z,
gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/torchbench",
use_existing_instance=False,
)

job_metric_config = metric_config.MetricConfig(
@@ -186,6 +186,7 @@ def get_gpu_vllm_gce_config(
timeout=datetime.timedelta(minutes=time_out_in_min),
task_owner=test_owner.RICHARD_L,
gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/vllm_benchmark",
use_existing_instance=False,
)

job_gcp_config = gcp_config.GCPConfig(
74 changes: 73 additions & 1 deletion xlml/apis/task.py
@@ -349,6 +349,7 @@ class GpuCreateResourceTask(BaseTask):
task_metric_config: metric configuration (e.g., result gcs path).
gpu_create_timeout: timeout when waiting for the GPU vm creation.
install_nvidia_drivers: whether to install Nvidia drivers.
existing_instance_name: name of an existing GPU instance to reuse; if set, no new instance is created.
"""

image_project: str
@@ -358,6 +359,7 @@ class GpuCreateResourceTask(BaseTask):
task_metric_config: Optional[metric_config.MetricConfig] = None
gpu_create_timeout: datetime.timedelta = datetime.timedelta(minutes=60)
install_nvidia_drivers: bool = False
existing_instance_name: str = None

def run(self) -> DAGNode:
"""Run a test job.
@@ -368,6 +370,9 @@ def run(self) -> DAGNode:
"""
# piz: We skip the queued resource for GPU for now since there is no queued
# resource command for GPU.
if self.existing_instance_name is not None:
return self.run_with_existing_instance()

with TaskGroup(
group_id=self.task_test_config.benchmark_id, prefix_group_id=True
) as group:
@@ -399,6 +404,58 @@ def run(self) -> DAGNode:
provision >> run_model >> post_process >> clean_up
return group

def run_with_existing_instance(self) -> DAGNode:
"""Run a test job via existing instance.

Returns:
A task group with the following tasks chained: provision, run_model and post_process, clean_up.
"""
with TaskGroup(
group_id=self.task_test_config.benchmark_id, prefix_group_id=True
) as group:
(
provision,
ip_address,
ssh_keys,
gcs_location,
) = self.provision_via_existing_instance()
if (
self.task_metric_config
and self.task_metric_config.use_runtime_generated_gcs_folder
):
env_variable = {
f"{metric_config.SshEnvVars.GCS_OUTPUT.name}": gcs_location
}
else:
env_variable = None
post_process = self.post_process(gcs_location)
run_model = self.run_model(ip_address, ssh_keys, env_variable)
clean_up = self.clean_up_existing_instance(ssh_keys)
provision >> run_model >> post_process >> clean_up
return group

def provision_via_existing_instance(
self,
) -> Tuple[DAGNode, airflow.XComArg, airflow.XComArg, airflow.XComArg,]:
"""Provision an existing GPU accelerator.

Returns:
A DAG node that provisions the existing GPU, an XCom value of the IP address
of the host, an XCom value for the SSH keys, and an XCom value for the GCS
output location.
"""
with TaskGroup(group_id="provision") as group:
ssh_keys = ssh.generate_ssh_keys()
ip_address = gpu.get_existing_resource(
instance_name=self.existing_instance_name,
ssh_keys=ssh_keys,
gcp=self.task_gcp_config,
)
gcs_location = name_format.generate_gcs_folder_location(
self.task_test_config.gcs_subfolder,
self.task_test_config.benchmark_id,
)
return group, ip_address, ssh_keys, gcs_location

def provision(
self,
) -> Tuple[
@@ -476,7 +533,8 @@ def run_model(
)

def post_process(
self, result_location: Optional[airflow.XComArg] = None
self,
result_location: Optional[airflow.XComArg] = None,
) -> DAGNode:
"""Process metrics and metadata, and insert them into BigQuery tables.

@@ -513,6 +571,20 @@ def clean_up(
resource, project_id, zone
)

def clean_up_existing_instance(self, ssh_keys: airflow.XComArg) -> DAGNode:
"""Clean up existing GPU resources - remove the one-time use generated ssh_keys.

Args:
ssh_keys: generated GPU's one-time use SSH keys to be removed.
Returns:
A DAG node that cleaned up the ssh_keys.
"""
return gpu.clean_up_ssh_keys(
instance_name=self.existing_instance_name,
ssh_keys=ssh_keys,
gcp=self.task_gcp_config,
)


# TODO(ranran): This class is big. Let's move it to a new file.
@dataclasses.dataclass
2 changes: 2 additions & 0 deletions xlml/apis/test_config.py
@@ -224,11 +224,13 @@ class GpuVmTest(TestConfig[Gpu]):
test_name: Unique name for this test/model.
set_up_cmds: List of commands to run once when GPU is created.
run_model_cmds: List of commands to run the model under test.
use_existing_instance: Whether to use an existing GPU instance.
"""

test_name: str
set_up_cmds: Iterable[str]
run_model_cmds: Iterable[str]
use_existing_instance: bool

@property
def benchmark_id(self) -> str:
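Taken together, the `task.py` and `test_config.py` changes define the pattern each config helper above follows. The sketch below restates that pattern under assumptions (the import paths, the helper name, and every argument other than the two reuse-related ones are elided or invented for illustration); it is not code from this PR.

```python
# Sketch of the opt-in pattern; import paths and elided arguments are assumptions.
from typing import Optional

from xlml.apis import task, test_config  # assumed import location


def get_my_gpu_config(
    existing_instance_name: Optional[str] = None,  # hypothetical helper, for illustration
    **unchanged_kwargs,
) -> task.GpuCreateResourceTask:
  job_test_config = test_config.GpuVmTest(
      # ... test_name, set_up_cmds, run_model_cmds, timeout, task_owner, gcs_subfolder ...
      use_existing_instance=existing_instance_name is not None,
  )
  # ... build job_gcp_config and job_metric_config exactly as before ...
  return task.GpuCreateResourceTask(
      # ... image/test/gcp/metric arguments unchanged ...
      existing_instance_name=existing_instance_name,
  )
```

With that wiring, `GpuCreateResourceTask.run()` dispatches to `run_with_existing_instance()`, so provisioning reduces to looking up the VM's IP and injecting a one-time SSH key, and clean-up removes only that key instead of deleting the instance.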
122 changes: 122 additions & 0 deletions xlml/utils/gpu.py
@@ -129,6 +129,128 @@ def generate_gpu_name() -> str:
return f"gpu-{str(uuid.uuid4())}"


@task
def get_existing_resource(
instance_name: str,
ssh_keys: ssh.SshKeys,
gcp: gcp_config.GCPConfig,
) -> airflow.XComArg:
"""Reach a resource node that is already created.

Args:
instance_name: name of the existing instance.
ssh_keys: the generated one-time use SSH keys to add to the instance metadata.
gcp: GCP project/zone configuration.

Returns:
The ip address of the GPU VM.
"""
instance_client = compute_v1.InstancesClient()
instance_request = compute_v1.GetInstanceRequest(
instance=instance_name,
project=gcp.project_name,
zone=gcp.zone,
)
instance = instance_client.get(request=instance_request)
logging.info(
f"Resource retrieve status: {instance.status}, {instance.status_message}"
)

ip_address = instance.network_interfaces[0].network_i_p
metadata = instance.metadata
items = metadata.items or []
ssh_key_exist = False
for item in metadata.items:
if item.key == "ssh-keys":
ssh_key_exist = True
item.value = (
item.value + "\n" + f"cloud-ml-auto-solutions:{ssh_keys.public}"
)
break
if not ssh_key_exist:
items.append({
"key": "ssh-keys",
"value": f"cloud-ml-auto-solutions:{ssh_keys.public}",
})
metadata.items = items
metadata_request = compute_v1.SetMetadataInstanceRequest(
instance=instance_name,
project=gcp.project_name,
zone=gcp.zone,
metadata_resource=metadata,
)
operation = instance_client.set_metadata(request=metadata_request)
if operation.error:
logging.error(
(
"Error during instance set metadata: [Code:"
f" {operation.http_error_status_code}]:"
f" {operation.http_error_message}"
f" {operation.error}"
),
)
raise operation.exception() or RuntimeError(operation.http_error_message)
elif operation.warnings:
logging.warning("Warnings during instance set metadata:\n")
for warning in operation.warnings:
logging.warning(f" - {warning.code}: {warning.message}")

return ip_address


@task(trigger_rule="all_done")
def clean_up_ssh_keys(
instance_name: str,
ssh_keys: ssh.SshKeys,
gcp: gcp_config.GCPConfig,
) -> airflow.XComArg:
"""Remove the generated one-time use ssh_keys from existing instance.

Args:
instance_name: name of the existing instance.
ssh_keys: the generated one-time use SSH keys to remove from the instance metadata.
gcp: GCP project/zone configuration.
"""
instance_client = compute_v1.InstancesClient()
instance_request = compute_v1.GetInstanceRequest(
instance=instance_name,
project=gcp.project_name,
zone=gcp.zone,
)
instance = instance_client.get(request=instance_request)
logging.info(
f"Resource get status: {instance.status}, {instance.status_message}"
)
metadata = instance.metadata
for item in metadata.items:
if item.key == "ssh-keys":
item.value = item.value.replace(
f"\ncloud-ml-auto-solutions:{ssh_keys.public}", ""
)
break
metadata_request = compute_v1.SetMetadataInstanceRequest(
instance=instance_name,
project=gcp.project_name,
zone=gcp.zone,
metadata_resource=metadata,
)
operation = instance_client.set_metadata(request=metadata_request)
if operation.error:
logging.error(
(
"Error during instance set metadata: [Code:"
f" {operation.http_error_status_code}]:"
f" {operation.http_error_message}"
f" {operation.error}"
),
)
raise operation.exception() or RuntimeError(operation.http_error_message)
elif operation.warnings:
logging.warning("Warnings during instance set metadata:\n")
for warning in operation.warnings:
logging.warning(f" - {warning.code}: {warning.message}")


@task_group
def create_resource(
gpu_name: airflow.XComArg,
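The `get_existing_resource` and `clean_up_ssh_keys` tasks above manage access by editing the instance's `ssh-keys` metadata entry as plain text: provisioning appends a `cloud-ml-auto-solutions:<public key>` line, and clean-up strips exactly that line. A minimal, self-contained sketch of that string handling, assuming GCE's standard one-entry-per-line `username:public_key` format:

```python
# Self-contained sketch of the ssh-keys metadata string handling used by
# get_existing_resource / clean_up_ssh_keys above. Assumes GCE's standard
# "username:public_key" one-entry-per-line format; "cloud-ml-auto-solutions"
# is the username hard-coded in this PR.

USERNAME = "cloud-ml-auto-solutions"


def add_one_time_key(ssh_keys_value: str, public_key: str) -> str:
  """Append the one-time key, as get_existing_resource does when an entry exists."""
  return ssh_keys_value + "\n" + f"{USERNAME}:{public_key}"


def remove_one_time_key(ssh_keys_value: str, public_key: str) -> str:
  """Strip exactly the appended line again, as clean_up_ssh_keys does."""
  return ssh_keys_value.replace(f"\n{USERNAME}:{public_key}", "")


# Round trip: the instance metadata is left as it was before the run.
base = "alice:ssh-ed25519 AAAAC3NzaC1... alice@example"
one_time = "ssh-ed25519 BBBBC3NzaC1... airflow-worker"
assert remove_one_time_key(add_one_time_key(base, one_time), one_time) == base
```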