diff --git a/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/deployment_graph.py b/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/deployment_graph.py
index 92cf7412387..eb4190debf1 100644
--- a/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/deployment_graph.py
+++ b/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/deployment_graph.py
@@ -37,6 +37,7 @@ def pipeline(
     policy_model_reference: str,
     model_display_name: Optional[str] = None,
     deploy_model: bool = True,
+    upload_model: bool = True,
     encryption_spec_key_name: str = '',
     upload_location: str = _placeholders.LOCATION_PLACEHOLDER,
     regional_endpoint: str = '',
@@ -59,40 +60,25 @@ def pipeline(
     endpoint_resource_name: Path the Online Prediction Endpoint. This will be
       an empty string if the model was not deployed.
   """
   # fmt: on
-  display_name = (
-      function_based.resolve_model_display_name(
-          large_model_reference=large_model_reference,
-          model_display_name=model_display_name,
-      )
-      .set_caching_options(False)
-      .set_display_name('Resolve Model Display Name')
-  )
-
-  upload_model = function_based.resolve_upload_model(
-      large_model_reference=policy_model_reference,
-  ).set_display_name('Resolve Upload Model')
   upload_task = upload_llm_model.refined_upload_llm_model(
       project=_placeholders.PROJECT_ID_PLACEHOLDER,
       location=upload_location,
       regional_endpoint=regional_endpoint,
       artifact_uri=output_adapter_path,
-      model_display_name=display_name.output,
+      model_display_name=model_display_name,
       model_reference_name=large_model_reference,
-      upload_model=upload_model.output,
+      upload_model=upload_model,
       encryption_spec_key_name=encryption_spec_key_name,
       tune_type='rlhf',
   ).set_display_name('Upload Model')
-  deploy_model = function_based.resolve_deploy_model(
-      deploy_model=deploy_model,
-      large_model_reference=policy_model_reference,
-  ).set_display_name('Resolve Deploy Model')
+
   deploy_task = deploy_llm_model.deploy_llm_model(
       project=_placeholders.PROJECT_ID_PLACEHOLDER,
       location=upload_location,
       model_resource_name=upload_task.outputs['model_resource_name'],
-      display_name=display_name.output,
+      display_name=model_display_name,
       regional_endpoint=regional_endpoint,
-      deploy_model=deploy_model.output,
+      deploy_model=deploy_model,
       encryption_spec_key_name=encryption_spec_key_name,
   ).set_display_name('Deploy Model')
   return PipelineOutput(
diff --git a/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/generated/refined_image_versions.py b/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/generated/refined_image_versions.py
index 40496daa087..6f350e884e5 100644
--- a/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/generated/refined_image_versions.py
+++ b/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/generated/refined_image_versions.py
@@ -17,4 +17,4 @@
 DO NOT EDIT - This file is generated, manual changes will be overridden.
 """
 
-IMAGE_TAG = '20240502_1327_RC00'
+IMAGE_TAG = '20240506_1530_RC00'
diff --git a/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/rlhf_preprocessor.py b/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/rlhf_preprocessor.py
index cb472df9dcb..06906d0a476 100644
--- a/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/rlhf_preprocessor.py
+++ b/components/google-cloud/google_cloud_pipeline_components/_implementation/llm/rlhf_preprocessor.py
@@ -46,12 +46,17 @@ def rlhf_preprocessor(
     metadata_refined_image_uri: dsl.OutputPath(str),  # pytype: disable=invalid-annotation
     metadata_num_microbatches: dsl.OutputPath(int),  # pytype: disable=invalid-annotation
     metadata_upload_location: dsl.OutputPath(str),  # pytype: disable=invalid-annotation
+    metadata_deploy_model: dsl.OutputPath(bool),  # pytype: disable=invalid-annotation
+    metadata_model_display_name: dsl.OutputPath(str),  # pytype: disable=invalid-annotation
+    metadata_upload_model: dsl.OutputPath(bool),  # pytype: disable=invalid-annotation
     use_experimental_image: bool = False,
     evaluation_dataset: str = '',
     tensorboard_resource_id: str = '',
     input_reference_model_path: str = '',
     image_uri: str = utils.get_default_image_uri('refined_cpu', ''),
     upload_location: str = '',
+    model_display_name: str = '',
+    deploy_model: bool = True,
 ) -> dsl.ContainerSpec:  # pylint: disable=g-doc-args
   # fmt: off
   """Preprocess RLHF pipeline inputs.
@@ -73,6 +78,8 @@
     metadata_reward_model_path: The model checkpoint path for the reward model.
     image_uri: Docker image URI to use for the custom job.
     upload_location: Region where the model will be uploaded.
+    model_display_name: Display name of the model.
+    deploy_model: Whether to deploy the model.
 
   Returns:
     gcp_resources: GCP resources that can be used to track the custom job.
@@ -86,6 +93,9 @@
     metadata_num_microbatches: Number of microbatches to break the total batch
      size into during training.
     metadata_upload_location: Regional endpoint.
+    metadata_deploy_model: Whether to deploy the model.
+    metadata_model_display_name: Display name of the model.
+    metadata_upload_model: Whether to upload the model.
   """
   # fmt: on
   return gcpc_utils.build_serverless_customjob_container_spec(
@@ -109,6 +119,8 @@ def rlhf_preprocessor(
              f'--tag={tag}',
              f'--use_experimental_image={use_experimental_image}',
              f'--upload_location={upload_location}',
+              f'--deploy_model={deploy_model}',
+              f'--model_display_name={model_display_name}',
              f'--has_tensorboard_id_path={has_tensorboard_id}',
              f'--has_inference_dataset_path={has_inference_dataset}',
              f'--metadata_candidate_columns_string_path={metadata_candidate_columns_string}',
@@ -123,6 +135,9 @@ def rlhf_preprocessor(
              f'--metadata_refined_image_uri_path={metadata_refined_image_uri}',
              f'--metadata_num_microbatches_path={metadata_num_microbatches}',
              f'--metadata_upload_location_path={metadata_upload_location}',
+              f'--metadata_deploy_model_path={metadata_deploy_model}',
+              f'--metadata_model_display_name_path={metadata_model_display_name}',
+              f'--metadata_upload_model_path={metadata_upload_model}',
          ],
      ),
      gcp_resources=gcp_resources,
diff --git a/components/google-cloud/google_cloud_pipeline_components/preview/llm/rlhf/component.py b/components/google-cloud/google_cloud_pipeline_components/preview/llm/rlhf/component.py
index eb8ee6cc772..c61aeebc576 100644
--- a/components/google-cloud/google_cloud_pipeline_components/preview/llm/rlhf/component.py
+++ b/components/google-cloud/google_cloud_pipeline_components/preview/llm/rlhf/component.py
@@ -107,6 +107,8 @@ def rlhf_pipeline(
       evaluation_dataset=eval_dataset,
       tensorboard_resource_id=tensorboard_resource_id,
       upload_location=location,
+      model_display_name=model_display_name,
+      deploy_model=deploy_model,
   ).set_display_name('Preprocess Inputs')
 
   num_microbatches = preprocess_metadata.outputs['metadata_num_microbatches']
@@ -230,8 +232,11 @@ def rlhf_pipeline(
       policy_model_reference=preprocess_metadata.outputs[
           'metadata_large_model_reference'
       ],
-      model_display_name=model_display_name,
-      deploy_model=deploy_model,
+      model_display_name=preprocess_metadata.outputs[
+          'metadata_model_display_name'
+      ],
+      deploy_model=preprocess_metadata.outputs['metadata_deploy_model'],
+      upload_model=preprocess_metadata.outputs['metadata_upload_model'],
       encryption_spec_key_name=encryption_spec_key_name,
       upload_location=location,
       regional_endpoint=preprocess_metadata.outputs['metadata_upload_location'],