Skip to content

Commit 58b342a

Browse files
author
Googler
committed
chore(components): Bump Starry Net images and enforce that TF Record generation always runs before test set generation to speed up pipeline runs
Signed-off-by: Googler <nobody@google.com> PiperOrigin-RevId: 655633942
1 parent 7660e8a commit 58b342a

File tree

5 files changed

+46
-39
lines changed

5 files changed

+46
-39
lines changed

components/google-cloud/RELEASE.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
## Upcoming release
22
* Updated the Starry Net pipeline's template gallery description, and added dataprep_nan_threshold and dataprep_zero_threshold args to the Starry Net pipeline.
33
* Fix bug in Starry Net's upload decomposition plot step due to protobuf upgrade, by pinning protobuf library to 3.20.*.
4+
* Bump Starry Net image tags.
5+
* In the Starry-Net pipeline, enforce that TF Record generation always runs before test set generation to speed up pipeline runs.
46
* Add support for running tasks on a `PersistentResource` (see [CustomJobSpec](https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/CustomJobSpec)) via `persistent_resource_id` parameter on `v1.custom_job.CustomTrainingJobOp` and `v1.custom_job.create_custom_training_job_from_component`
57
* Bump image for Structured Data pipelines.
68

components/google-cloud/google_cloud_pipeline_components/_implementation/starry_net/dataprep/component.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def dataprep(
3333
ts_identifier_columns: str,
3434
time_column: str,
3535
static_covariate_columns: str,
36+
static_covariates_vocab_path: str, # pytype: disable=unused-argument
3637
target_column: str,
3738
machine_type: str,
3839
docker_region: str,
@@ -78,6 +79,8 @@ def dataprep(
7879
data source.
7980
time_column: The column with timestamps in the BigQuery source.
8081
static_covariate_columns: The names of the static covariates.
82+
static_covariates_vocab_path: The path to the master static covariates vocab
83+
json.
8184
target_column: The target column in the Big Query data source.
8285
machine_type: The machine type of the dataflow workers.
8386
docker_region: The docker region, used to determine which image to use.

components/google-cloud/google_cloud_pipeline_components/_implementation/starry_net/get_training_artifacts/component.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def get_training_artifacts(
5555
instance_schema_uri=str,
5656
)
5757
return outputs(
58-
f'{docker_region}-docker.pkg.dev/vertex-ai/starryn/predictor:20240617_2142_RC00', # pylint: disable=too-many-function-args
58+
f'{docker_region}-docker.pkg.dev/vertex-ai/starryn/predictor:20240723_0542_RC00', # pylint: disable=too-many-function-args
5959
private_dir, # pylint: disable=too-many-function-args
6060
os.path.join(private_dir, 'predict_schema.yaml'), # pylint: disable=too-many-function-args
6161
os.path.join(private_dir, 'instance_schema.yaml'), # pylint: disable=too-many-function-args

components/google-cloud/google_cloud_pipeline_components/_implementation/starry_net/version.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,6 @@
1313
# limitations under the License.
1414
"""Version constants for starry net components."""
1515

16-
DATAPREP_VERSION = '20240617_2225_RC00'
17-
PREDICTOR_VERSION = '20240617_2142_RC00'
18-
TRAINER_VERSION = '20240617_2142_RC00'
16+
DATAPREP_VERSION = '20240722_2225_RC00'
17+
PREDICTOR_VERSION = '20240723_0542_RC00'
18+
TRAINER_VERSION = '20240723_0542_RC00'

components/google-cloud/google_cloud_pipeline_components/preview/starry_net/component.py

Lines changed: 37 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -239,41 +239,6 @@ def starry_net( # pylint: disable=dangerous-default-value
239239
model_blocks=trainer_model_blocks,
240240
static_covariates=dataprep_static_covariate_columns,
241241
)
242-
test_set_task = DataprepOp(
243-
backcast_length=dataprep_backcast_length,
244-
forecast_length=dataprep_forecast_length,
245-
train_end_date=dataprep_train_end_date,
246-
n_val_windows=dataprep_n_val_windows,
247-
n_test_windows=dataprep_n_test_windows,
248-
test_set_stride=dataprep_test_set_stride,
249-
model_blocks=create_dataprep_args_task.outputs['model_blocks'],
250-
bigquery_source=dataprep_bigquery_data_path,
251-
ts_identifier_columns=create_dataprep_args_task.outputs[
252-
'ts_identifier_columns'],
253-
time_column=dataprep_time_column,
254-
static_covariate_columns=create_dataprep_args_task.outputs[
255-
'static_covariate_columns'],
256-
target_column=dataprep_target_column,
257-
machine_type=dataflow_machine_type,
258-
docker_region=create_dataprep_args_task.outputs['docker_region'],
259-
location=location,
260-
project=project,
261-
job_id=job_id,
262-
job_name_prefix='test-set',
263-
num_workers=dataflow_starting_replica_count,
264-
max_num_workers=dataflow_max_replica_count,
265-
disk_size_gb=dataflow_disk_size_gb,
266-
test_set_only=True,
267-
bigquery_output=dataprep_test_set_bigquery_dataset,
268-
nan_threshold=dataprep_nan_threshold,
269-
zero_threshold=dataprep_zero_threshold,
270-
gcs_source=dataprep_csv_data_path,
271-
gcs_static_covariate_source=dataprep_csv_static_covariates_path,
272-
encryption_spec_key_name=encryption_spec_key_name
273-
)
274-
test_set_task.set_display_name('create-test-set')
275-
set_test_set_task = SetTestSetOp(
276-
dataprep_dir=test_set_task.outputs['dataprep_dir'])
277242
with dsl.If(create_dataprep_args_task.outputs['create_tf_records'] == True, # pylint: disable=singleton-comparison
278243
'create-tf-records'):
279244
create_tf_records_task = DataprepOp(
@@ -290,6 +255,7 @@ def starry_net( # pylint: disable=dangerous-default-value
290255
time_column=dataprep_time_column,
291256
static_covariate_columns=create_dataprep_args_task.outputs[
292257
'static_covariate_columns'],
258+
static_covariates_vocab_path='',
293259
target_column=dataprep_target_column,
294260
machine_type=dataflow_machine_type,
295261
docker_region=create_dataprep_args_task.outputs['docker_region'],
@@ -325,6 +291,42 @@ def starry_net( # pylint: disable=dangerous-default-value
325291
'static_covariates_vocab_path'],
326292
set_tfrecord_args_this_run_task.outputs['static_covariates_vocab_path']
327293
)
294+
test_set_task = DataprepOp(
295+
backcast_length=dataprep_backcast_length,
296+
forecast_length=dataprep_forecast_length,
297+
train_end_date=dataprep_train_end_date,
298+
n_val_windows=dataprep_n_val_windows,
299+
n_test_windows=dataprep_n_test_windows,
300+
test_set_stride=dataprep_test_set_stride,
301+
model_blocks=create_dataprep_args_task.outputs['model_blocks'],
302+
bigquery_source=dataprep_bigquery_data_path,
303+
ts_identifier_columns=create_dataprep_args_task.outputs[
304+
'ts_identifier_columns'],
305+
time_column=dataprep_time_column,
306+
static_covariate_columns=create_dataprep_args_task.outputs[
307+
'static_covariate_columns'],
308+
static_covariates_vocab_path=static_covariates_vocab_path,
309+
target_column=dataprep_target_column,
310+
machine_type=dataflow_machine_type,
311+
docker_region=create_dataprep_args_task.outputs['docker_region'],
312+
location=location,
313+
project=project,
314+
job_id=job_id,
315+
job_name_prefix='test-set',
316+
num_workers=dataflow_starting_replica_count,
317+
max_num_workers=dataflow_max_replica_count,
318+
disk_size_gb=dataflow_disk_size_gb,
319+
test_set_only=True,
320+
bigquery_output=dataprep_test_set_bigquery_dataset,
321+
nan_threshold=dataprep_nan_threshold,
322+
zero_threshold=dataprep_zero_threshold,
323+
gcs_source=dataprep_csv_data_path,
324+
gcs_static_covariate_source=dataprep_csv_static_covariates_path,
325+
encryption_spec_key_name=encryption_spec_key_name
326+
)
327+
test_set_task.set_display_name('create-test-set')
328+
set_test_set_task = SetTestSetOp(
329+
dataprep_dir=test_set_task.outputs['dataprep_dir'])
328330
train_tf_record_patterns = dsl.OneOf(
329331
set_tfrecord_args_previous_run_task.outputs['train_tf_record_patterns'],
330332
set_tfrecord_args_this_run_task.outputs['train_tf_record_patterns']

0 commit comments

Comments
 (0)