From 7908ed664653143d335ba3e9227484347e64577d Mon Sep 17 00:00:00 2001
From: Googler <nobody@google.com>
Date: Mon, 6 May 2024 20:01:39 -0700
Subject: [PATCH] feat(components): use preprocessor utility methods for the
 upload model graph

PiperOrigin-RevId: 631266689
---
 .../_implementation/llm/deployment_graph.py   | 26 +++++--------------
 .../llm/generated/refined_image_versions.py   |  2 +-
 .../_implementation/llm/rlhf_preprocessor.py  | 15 +++++++++++
 .../preview/llm/rlhf/component.py             |  9 +++++--
 4 files changed, 29 insertions(+), 23 deletions(-)

diff --git a/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/deployment_graph.py b/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/deployment_graph.py
index 92cf7412387..eb4190debf1 100644
--- a/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/deployment_graph.py
+++ b/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/deployment_graph.py
@@ -37,6 +37,7 @@ def pipeline(
     policy_model_reference: str,
     model_display_name: Optional[str] = None,
     deploy_model: bool = True,
+    upload_model: bool = True,
     encryption_spec_key_name: str = '',
     upload_location: str = _placeholders.LOCATION_PLACEHOLDER,
     regional_endpoint: str = '',
@@ -59,40 +60,25 @@ def pipeline(
     endpoint_resource_name: Path the Online Prediction Endpoint. This will be an empty string if the model was not deployed.
   """
   # fmt: on
-  display_name = (
-      function_based.resolve_model_display_name(
-          large_model_reference=large_model_reference,
-          model_display_name=model_display_name,
-      )
-      .set_caching_options(False)
-      .set_display_name('Resolve Model Display Name')
-  )
-
-  upload_model = function_based.resolve_upload_model(
-      large_model_reference=policy_model_reference,
-  ).set_display_name('Resolve Upload Model')
   upload_task = upload_llm_model.refined_upload_llm_model(
       project=_placeholders.PROJECT_ID_PLACEHOLDER,
       location=upload_location,
       regional_endpoint=regional_endpoint,
       artifact_uri=output_adapter_path,
-      model_display_name=display_name.output,
+      model_display_name=model_display_name,
       model_reference_name=large_model_reference,
-      upload_model=upload_model.output,
+      upload_model=upload_model,
       encryption_spec_key_name=encryption_spec_key_name,
       tune_type='rlhf',
   ).set_display_name('Upload Model')
-  deploy_model = function_based.resolve_deploy_model(
-      deploy_model=deploy_model,
-      large_model_reference=policy_model_reference,
-  ).set_display_name('Resolve Deploy Model')
+
   deploy_task = deploy_llm_model.deploy_llm_model(
       project=_placeholders.PROJECT_ID_PLACEHOLDER,
       location=upload_location,
       model_resource_name=upload_task.outputs['model_resource_name'],
-      display_name=display_name.output,
+      display_name=model_display_name,
       regional_endpoint=regional_endpoint,
-      deploy_model=deploy_model.output,
+      deploy_model=deploy_model,
       encryption_spec_key_name=encryption_spec_key_name,
   ).set_display_name('Deploy Model')
   return PipelineOutput(
diff --git a/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/generated/refined_image_versions.py b/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/generated/refined_image_versions.py
index 40496daa087..6f350e884e5 100644
--- a/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/generated/refined_image_versions.py
+++ b/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/generated/refined_image_versions.py
@@ -17,4 +17,4 @@
 DO NOT EDIT - This file is generated, manual changes will be overridden.
 """
 
-IMAGE_TAG = '20240502_1327_RC00'
+IMAGE_TAG = '20240506_1530_RC00'
diff --git a/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/rlhf_preprocessor.py b/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/rlhf_preprocessor.py
index cb472df9dcb..06906d0a476 100644
--- a/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/rlhf_preprocessor.py
+++ b/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/rlhf_preprocessor.py
@@ -46,12 +46,17 @@ def rlhf_preprocessor(
     metadata_refined_image_uri: dsl.OutputPath(str),  # pytype: disable=invalid-annotation
     metadata_num_microbatches: dsl.OutputPath(int),  # pytype: disable=invalid-annotation
     metadata_upload_location: dsl.OutputPath(str),  # pytype: disable=invalid-annotation
+    metadata_deploy_model: dsl.OutputPath(bool),  # pytype: disable=invalid-annotation
+    metadata_model_display_name: dsl.OutputPath(str),  # pytype: disable=invalid-annotation
+    metadata_upload_model: dsl.OutputPath(bool),  # pytype: disable=invalid-annotation
     use_experimental_image: bool = False,
     evaluation_dataset: str = '',
     tensorboard_resource_id: str = '',
     input_reference_model_path: str = '',
     image_uri: str = utils.get_default_image_uri('refined_cpu', ''),
     upload_location: str = '',
+    model_display_name: str = '',
+    deploy_model: bool = True,
 ) -> dsl.ContainerSpec:  # pylint: disable=g-doc-args
   # fmt: off
   """Preprocess RLHF pipeline inputs.
@@ -73,6 +78,8 @@ def rlhf_preprocessor(
     metadata_reward_model_path: The model checkpoint path for the reward model.
     image_uri: Docker image URI to use for the custom job.
     upload_location: Region where the model will be uploaded.
+    model_display_name: Display name of the model.
+    deploy_model: Whether to deploy the model.
 
   Returns:
     gcp_resources: GCP resources that can be used to track the custom job.
@@ -86,6 +93,9 @@ def rlhf_preprocessor(
     metadata_num_microbatches: Number of microbatches to break the total batch
       size into during training.
     metadata_upload_location: Regional endpoint.
+    metadata_deploy_model: Whether to deploy the model.
+    metadata_model_display_name: Display name of the model.
+    metadata_upload_model: Whether to upload the model.
   """
   # fmt: on
   return gcpc_utils.build_serverless_customjob_container_spec(
@@ -109,6 +119,8 @@ def rlhf_preprocessor(
               f'--tag={tag}',
               f'--use_experimental_image={use_experimental_image}',
               f'--upload_location={upload_location}',
+              f'--deploy_model={deploy_model}',
+              f'--model_display_name={model_display_name}',
               f'--has_tensorboard_id_path={has_tensorboard_id}',
               f'--has_inference_dataset_path={has_inference_dataset}',
               f'--metadata_candidate_columns_string_path={metadata_candidate_columns_string}',
@@ -123,6 +135,9 @@ def rlhf_preprocessor(
               f'--metadata_refined_image_uri_path={metadata_refined_image_uri}',
               f'--metadata_num_microbatches_path={metadata_num_microbatches}',
               f'--metadata_upload_location_path={metadata_upload_location}',
+              f'--metadata_deploy_model_path={metadata_deploy_model}',
+              f'--metadata_model_display_name_path={metadata_model_display_name}',
+              f'--metadata_upload_model_path={metadata_upload_model}',
           ],
       ),
       gcp_resources=gcp_resources,
diff --git a/components/google-cloud/google_cloud_pipeline_components/preview/llm/rlhf/component.py b/components/google-cloud/google_cloud_pipeline_components/preview/llm/rlhf/component.py
index eb8ee6cc772..c61aeebc576 100644
--- a/components/google-cloud/google_cloud_pipeline_components/preview/llm/rlhf/component.py
+++ b/components/google-cloud/google_cloud_pipeline_components/preview/llm/rlhf/component.py
@@ -107,6 +107,8 @@ def rlhf_pipeline(
       evaluation_dataset=eval_dataset,
       tensorboard_resource_id=tensorboard_resource_id,
       upload_location=location,
+      model_display_name=model_display_name,
+      deploy_model=deploy_model,
   ).set_display_name('Preprocess Inputs')
   num_microbatches = preprocess_metadata.outputs['metadata_num_microbatches']
 
@@ -230,8 +232,11 @@ def rlhf_pipeline(
       policy_model_reference=preprocess_metadata.outputs[
           'metadata_large_model_reference'
       ],
-      model_display_name=model_display_name,
-      deploy_model=deploy_model,
+      model_display_name=preprocess_metadata.outputs[
+          'metadata_model_display_name'
+      ],
+      deploy_model=preprocess_metadata.outputs['metadata_deploy_model'],
+      upload_model=preprocess_metadata.outputs['metadata_upload_model'],
       encryption_spec_key_name=encryption_spec_key_name,
       upload_location=location,
       regional_endpoint=preprocess_metadata.outputs['metadata_upload_location'],