Skip to content

Commit

Permalink
feat(components): use preprocessor utility methods for the upload mod…
Browse files Browse the repository at this point in the history
…el graph

PiperOrigin-RevId: 631266689
  • Loading branch information
Googler committed May 7, 2024
1 parent 804ee00 commit 7908ed6
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def pipeline(
policy_model_reference: str,
model_display_name: Optional[str] = None,
deploy_model: bool = True,
upload_model: bool = True,
encryption_spec_key_name: str = '',
upload_location: str = _placeholders.LOCATION_PLACEHOLDER,
regional_endpoint: str = '',
Expand All @@ -59,40 +60,25 @@ def pipeline(
endpoint_resource_name: Path the Online Prediction Endpoint. This will be an empty string if the model was not deployed.
"""
# fmt: on
display_name = (
function_based.resolve_model_display_name(
large_model_reference=large_model_reference,
model_display_name=model_display_name,
)
.set_caching_options(False)
.set_display_name('Resolve Model Display Name')
)

upload_model = function_based.resolve_upload_model(
large_model_reference=policy_model_reference,
).set_display_name('Resolve Upload Model')
upload_task = upload_llm_model.refined_upload_llm_model(
project=_placeholders.PROJECT_ID_PLACEHOLDER,
location=upload_location,
regional_endpoint=regional_endpoint,
artifact_uri=output_adapter_path,
model_display_name=display_name.output,
model_display_name=model_display_name,
model_reference_name=large_model_reference,
upload_model=upload_model.output,
upload_model=upload_model,
encryption_spec_key_name=encryption_spec_key_name,
tune_type='rlhf',
).set_display_name('Upload Model')
deploy_model = function_based.resolve_deploy_model(
deploy_model=deploy_model,
large_model_reference=policy_model_reference,
).set_display_name('Resolve Deploy Model')

deploy_task = deploy_llm_model.deploy_llm_model(
project=_placeholders.PROJECT_ID_PLACEHOLDER,
location=upload_location,
model_resource_name=upload_task.outputs['model_resource_name'],
display_name=display_name.output,
display_name=model_display_name,
regional_endpoint=regional_endpoint,
deploy_model=deploy_model.output,
deploy_model=deploy_model,
encryption_spec_key_name=encryption_spec_key_name,
).set_display_name('Deploy Model')
return PipelineOutput(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
DO NOT EDIT - This file is generated, manual changes will be overridden.
"""

IMAGE_TAG = '20240502_1327_RC00'
IMAGE_TAG = '20240506_1530_RC00'
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,17 @@ def rlhf_preprocessor(
metadata_refined_image_uri: dsl.OutputPath(str), # pytype: disable=invalid-annotation
metadata_num_microbatches: dsl.OutputPath(int), # pytype: disable=invalid-annotation
metadata_upload_location: dsl.OutputPath(str), # pytype: disable=invalid-annotation
metadata_deploy_model: dsl.OutputPath(bool), # pytype: disable=invalid-annotation
metadata_model_display_name: dsl.OutputPath(str), # pytype: disable=invalid-annotation
metadata_upload_model: dsl.OutputPath(bool), # pytype: disable=invalid-annotation
use_experimental_image: bool = False,
evaluation_dataset: str = '',
tensorboard_resource_id: str = '',
input_reference_model_path: str = '',
image_uri: str = utils.get_default_image_uri('refined_cpu', ''),
upload_location: str = '',
model_display_name: str = '',
deploy_model: bool = True,
) -> dsl.ContainerSpec: # pylint: disable=g-doc-args
# fmt: off
"""Preprocess RLHF pipeline inputs.
Expand All @@ -73,6 +78,8 @@ def rlhf_preprocessor(
metadata_reward_model_path: The model checkpoint path for the reward model.
image_uri: Docker image URI to use for the custom job.
upload_location: Region where the model will be uploaded.
model_display_name: Display name of the model.
deploy_model: Whether to deploy the model.
Returns:
gcp_resources: GCP resources that can be used to track the custom job.
Expand All @@ -86,6 +93,9 @@ def rlhf_preprocessor(
metadata_num_microbatches: Number of microbatches to break the total batch
size into during training.
metadata_upload_location: Regional endpoint.
metadata_deploy_model: Whether to deploy the model.
metadata_model_display_name: Display name of the model.
metadata_upload_model: Whether to upload the model.
"""
# fmt: on
return gcpc_utils.build_serverless_customjob_container_spec(
Expand All @@ -109,6 +119,8 @@ def rlhf_preprocessor(
f'--tag={tag}',
f'--use_experimental_image={use_experimental_image}',
f'--upload_location={upload_location}',
f'--deploy_model={deploy_model}',
f'--model_display_name={model_display_name}',
f'--has_tensorboard_id_path={has_tensorboard_id}',
f'--has_inference_dataset_path={has_inference_dataset}',
f'--metadata_candidate_columns_string_path={metadata_candidate_columns_string}',
Expand All @@ -123,6 +135,9 @@ def rlhf_preprocessor(
f'--metadata_refined_image_uri_path={metadata_refined_image_uri}',
f'--metadata_num_microbatches_path={metadata_num_microbatches}',
f'--metadata_upload_location_path={metadata_upload_location}',
f'--metadata_deploy_model_path={metadata_deploy_model}',
f'--metadata_model_display_name_path={metadata_model_display_name}',
f'--metadata_upload_model_path={metadata_upload_model}',
],
),
gcp_resources=gcp_resources,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ def rlhf_pipeline(
evaluation_dataset=eval_dataset,
tensorboard_resource_id=tensorboard_resource_id,
upload_location=location,
model_display_name=model_display_name,
deploy_model=deploy_model,
).set_display_name('Preprocess Inputs')
num_microbatches = preprocess_metadata.outputs['metadata_num_microbatches']

Expand Down Expand Up @@ -230,8 +232,11 @@ def rlhf_pipeline(
policy_model_reference=preprocess_metadata.outputs[
'metadata_large_model_reference'
],
model_display_name=model_display_name,
deploy_model=deploy_model,
model_display_name=preprocess_metadata.outputs[
'metadata_model_display_name'
],
deploy_model=preprocess_metadata.outputs['metadata_deploy_model'],
upload_model=preprocess_metadata.outputs['metadata_upload_model'],
encryption_spec_key_name=encryption_spec_key_name,
upload_location=location,
regional_endpoint=preprocess_metadata.outputs['metadata_upload_location'],
Expand Down

0 comments on commit 7908ed6

Please sign in to comment.