
Commit f1e2314

chensun and Googler authored
Expose starry_net yaml to GitHub (#10943)
Fix GitHub test error:

```
> with open(file_path, 'r') as component_stream:
E FileNotFoundError: [Errno 2] No such file or directory: '/usr/local/lib/python3.8/site-packages/google_cloud_pipeline_components/_implementation/starry_net/evaluation/evaluation.yaml'

/usr/local/lib/python3.8/site-packages/kfp/components/load_yaml_utilities.py:53: FileNotFoundError
```

PiperOrigin-RevId: 645623029

Signed-off-by: Googler <nobody@google.com>
Co-authored-by: Googler <nobody@google.com>
1 parent 48b2d3f commit f1e2314
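For context, the failure happens when a component YAML that was not shipped in the installed package is loaded. A minimal sketch (not the repo's test code) of the kind of load that trips the error; the relative path comes from the error message above:

```python
import os

from kfp import components

import google_cloud_pipeline_components as gcpc

# Path taken from the error message, resolved relative to the installed package.
yaml_path = os.path.join(
    os.path.dirname(gcpc.__file__),
    '_implementation', 'starry_net', 'evaluation', 'evaluation.yaml',
)

# Before this commit the .yaml was missing from the distribution, so this call
# raised FileNotFoundError from kfp/components/load_yaml_utilities.py.
evaluation_op = components.load_component_from_file(yaml_path)
print(type(evaluation_op))
```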

2 files changed: +234, -0 lines changed
@@ -0,0 +1,197 @@
name: model_evaluation_forecasting
description: |
  Computes a google.ForecastingMetrics Artifact, containing evaluation metrics given a model's prediction results.
  Creates a dataflow job with Apache Beam and TFMA to compute evaluation metrics.
  Supports point forecasting and quantile forecasting for tabular data.

  Args:
      project (str):
          Project to run the evaluation container.
      location (Optional[str]):
          Location for running the evaluation.
          If not set, defaulted to `us-central1`.
      root_dir (str):
          The GCS directory for keeping staging files.
          A random subdirectory will be created under the directory to keep job info for resuming
          the job in case of failure.
      predictions_format (Optional[str]):
          The file format for the batch prediction results. `jsonl` is currently the only allowed
          format.
          If not set, defaulted to `jsonl`.
      predictions_gcs_source (Optional[system.Artifact]):
          An artifact with its URI pointing toward a GCS directory with prediction or explanation
          files to be used for this evaluation.
          For prediction results, the files should be named "prediction.results-*".
          For explanation results, the files should be named "explanation.results-*".
      predictions_bigquery_source (Optional[google.BQTable]):
          BigQuery table with prediction or explanation data to be used for this evaluation.
          For prediction results, the table column should be named "predicted_*".
      ground_truth_format (Optional[str]):
          Required for custom tabular and non-tabular data.
          The file format for the ground truth files. `jsonl` is currently the only allowed format.
          If not set, defaulted to `jsonl`.
      ground_truth_gcs_source (Optional[Sequence[str]]):
          Required for custom tabular and non-tabular data.
          The GCS URIs representing where the ground truth is located.
          Used to provide ground truth for each prediction instance when they are not part of the batch prediction jobs prediction instance.
      ground_truth_bigquery_source (Optional[str]):
          Required for custom tabular.
          The BigQuery table URI representing where the ground truth is located.
          Used to provide ground truth for each prediction instance when they are not part of the batch prediction jobs prediction instance.
      target_field_name (str):
          The full name path of the features target field in the predictions file.
          Formatted to be able to find nested columns, delimited by `.`.
          Alternatively referred to as the ground truth (or ground_truth_column) field.
      model (Optional[google.VertexModel]):
          The Model used for the predictions job.
          Must share the same ancestor Location.
      prediction_score_column (Optional[str]):
          Optional. The column name of the field containing batch prediction scores.
          Formatted to be able to find nested columns, delimited by `.`.
          If not set, defaulted to `prediction.value` for a `point` forecasting_type and
          `prediction.quantile_predictions` for a `quantile` forecasting_type.
      forecasting_type (Optional[str]):
          Optional. If the problem_type is `forecasting`, then the forecasting type being addressed
          by this regression evaluation run. `point` and `quantile` are the supported types.
          If not set, defaulted to `point`.
      forecasting_quantiles (Optional[Sequence[Float]]):
          Required for a `quantile` forecasting_type.
          The list of quantiles in the same order as they appear in the quantile prediction score column.
          If one of the quantiles is set to `0.5f`, point evaluation will be set on that index.
      example_weight_column (Optional[str]):
          Optional. The column name of the field containing example weights.
      point_evaluation_quantile (Optional[Float]):
          Required for a `quantile` forecasting_type.
          A quantile in the list of forecasting_quantiles that will be used for point evaluation
          metrics.
      dataflow_service_account (Optional[str]):
          Optional. Service account to run the dataflow job.
          If not set, dataflow will use the default worker service account.
          For more details, see https://cloud.google.com/dataflow/docs/concepts/security-and-permissions#default_worker_service_account
      dataflow_disk_size (Optional[int]):
          Optional. The disk size (in GB) of the machine executing the evaluation run.
          If not set, defaulted to `50`.
      dataflow_machine_type (Optional[str]):
          Optional. The machine type executing the evaluation run.
          If not set, defaulted to `n1-standard-4`.
      dataflow_workers_num (Optional[int]):
          Optional. The number of workers executing the evaluation run.
          If not set, defaulted to `10`.
      dataflow_max_workers_num (Optional[int]):
          Optional. The max number of workers executing the evaluation run.
          If not set, defaulted to `25`.
      dataflow_subnetwork (Optional[str]):
          Dataflow's fully qualified subnetwork name, when empty the default subnetwork will be
          used. More details:
          https://cloud.google.com/dataflow/docs/guides/specifying-networks#example_network_and_subnetwork_specifications
      dataflow_use_public_ips (Optional[bool]):
          Specifies whether Dataflow workers use public IP addresses.
      encryption_spec_key_name (Optional[str]):
          Customer-managed encryption key.

  Returns:
      evaluation_metrics (google.ForecastingMetrics):
          google.ForecastingMetrics artifact representing the forecasting evaluation metrics in GCS.
inputs:
  - { name: project, type: String }
  - { name: location, type: String, default: "us-central1" }
  - { name: root_dir, type: system.Artifact }
  - { name: predictions_format, type: String, default: "jsonl" }
  - { name: predictions_gcs_source, type: Artifact, optional: True }
  - { name: predictions_bigquery_source, type: google.BQTable, optional: True }
  - { name: ground_truth_format, type: String, default: "jsonl" }
  - { name: ground_truth_gcs_source, type: JsonArray, default: "[]" }
  - { name: ground_truth_bigquery_source, type: String, default: "" }
  - { name: target_field_name, type: String }
  - { name: model, type: google.VertexModel, optional: True }
  - { name: prediction_score_column, type: String, default: "" }
  - { name: forecasting_type, type: String, default: "point" }
  - { name: forecasting_quantiles, type: JsonArray, default: "[0.5]" }
  - { name: example_weight_column, type: String, default: "" }
  - { name: point_evaluation_quantile, type: Float, default: 0.5 }
  - { name: dataflow_service_account, type: String, default: "" }
  - { name: dataflow_disk_size, type: Integer, default: 50 }
  - { name: dataflow_machine_type, type: String, default: "n1-standard-4" }
  - { name: dataflow_workers_num, type: Integer, default: 1 }
  - { name: dataflow_max_workers_num, type: Integer, default: 5 }
  - { name: dataflow_subnetwork, type: String, default: "" }
  - { name: dataflow_use_public_ips, type: Boolean, default: "true" }
  - { name: encryption_spec_key_name, type: String, default: "" }
outputs:
  - { name: evaluation_metrics, type: google.ForecastingMetrics }
  - { name: gcp_resources, type: String }
implementation:
  container:
    image: gcr.io/ml-pipeline/model-evaluation:v0.9
    command:
      - python
      - /main.py
    args:
      - --setup_file
      - /setup.py
      - --json_mode
      - "true"
      - --project_id
      - { inputValue: project }
      - --location
      - { inputValue: location }
      - --problem_type
      - "forecasting"
      - --forecasting_type
      - { inputValue: forecasting_type }
      - --forecasting_quantiles
      - { inputValue: forecasting_quantiles }
      - --point_evaluation_quantile
      - { inputValue: point_evaluation_quantile }
      - --batch_prediction_format
      - { inputValue: predictions_format }
      - if:
          cond: { isPresent: predictions_gcs_source }
          then:
            - --batch_prediction_gcs_source
            - "{{$.inputs.artifacts['predictions_gcs_source'].uri}}"
      - if:
          cond: { isPresent: predictions_bigquery_source }
          then:
            - --batch_prediction_bigquery_source
            - "bq://{{$.inputs.artifacts['predictions_bigquery_source'].metadata['projectId']}}.{{$.inputs.artifacts['predictions_bigquery_source'].metadata['datasetId']}}.{{$.inputs.artifacts['predictions_bigquery_source'].metadata['tableId']}}"
      - if:
          cond: { isPresent: model }
          then:
            - --model_name
            - "{{$.inputs.artifacts['model'].metadata['resourceName']}}"
      - --ground_truth_format
      - { inputValue: ground_truth_format }
      - --ground_truth_gcs_source
      - { inputValue: ground_truth_gcs_source }
      - --ground_truth_bigquery_source
      - { inputValue: ground_truth_bigquery_source }
      - --root_dir
      - "{{$.inputs.artifacts['root_dir'].uri}}"
      - --target_field_name
      - "instance.{{$.inputs.parameters['target_field_name']}}"
      - --prediction_score_column
      - { inputValue: prediction_score_column }
      - --dataflow_job_prefix
      - "evaluation-{{$.pipeline_job_uuid}}-{{$.pipeline_task_uuid}}"
      - --dataflow_service_account
      - { inputValue: dataflow_service_account }
      - --dataflow_disk_size
      - { inputValue: dataflow_disk_size }
      - --dataflow_machine_type
      - { inputValue: dataflow_machine_type }
      - --dataflow_workers_num
      - { inputValue: dataflow_workers_num }
      - --dataflow_max_workers_num
      - { inputValue: dataflow_max_workers_num }
      - --dataflow_subnetwork
      - { inputValue: dataflow_subnetwork }
      - --dataflow_use_public_ips
      - { inputValue: dataflow_use_public_ips }
      - --kms_key_name
      - { inputValue: encryption_spec_key_name }
      - --output_metrics_gcs_path
      - { outputUri: evaluation_metrics }
      - --gcp_resources
      - { outputPath: gcp_resources }
      - --executor_input
      - "{{$}}"
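A hypothetical sketch (not from this repo) of wiring the model_evaluation_forecasting component above into a KFP v2 pipeline. The bucket paths, project ID, and target field name are placeholders; the optional BigQuery source and model inputs are simply left unset:

```python
from kfp import compiler, components, dsl

# Assumes a local copy of the component spec shown above.
eval_op = components.load_component_from_file('evaluation.yaml')


@dsl.pipeline(name='forecasting-eval-demo')
def forecasting_eval(project: str = 'my-project'):
    # Importers turn plain GCS URIs into the artifact inputs the component expects.
    root_dir = dsl.importer(
        artifact_uri='gs://my-bucket/eval-staging',
        artifact_class=dsl.Artifact,
    )
    predictions = dsl.importer(
        artifact_uri='gs://my-bucket/batch-prediction-output',
        artifact_class=dsl.Artifact,
    )
    eval_op(
        project=project,
        root_dir=root_dir.output,
        predictions_gcs_source=predictions.output,
        ground_truth_gcs_source=['gs://my-bucket/ground_truth.jsonl'],
        target_field_name='sales',  # placeholder target column
        forecasting_type='quantile',
        forecasting_quantiles=[0.1, 0.5, 0.9],
        point_evaluation_quantile=0.5,
    )


compiler.Compiler().compile(forecasting_eval, 'forecasting_eval.json')
```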
@@ -0,0 +1,37 @@
name: model_upload
inputs:
  - {name: project, type: String}
  - {name: location, type: String, default: "us-central1"}
  - {name: display_name, type: String}
  - {name: description, type: String, optional: true, default: ''}
  - {name: unmanaged_container_model, type: google.UnmanagedContainerModel, optional: true}
  - {name: encryption_spec_key_name, type: String, optional: true, default: ''}
  - {name: labels, type: JsonObject, optional: true, default: '{}'}
  - {name: parent_model, type: google.VertexModel, optional: true}
outputs:
  - {name: model, type: google.VertexModel}
  - {name: gcp_resources, type: String}
implementation:
  container:
    image: gcr.io/ml-pipeline/automl-tables-private:1.0.17
    command: [python3, -u, -m, launcher]
    args: [
      --type, UploadModel,
      --payload,
      concat: [
        '{',
        '"display_name": "', {inputValue: display_name}, '"',
        ', "description": "', {inputValue: description}, '"',
        ', "encryption_spec": {"kms_key_name":"', {inputValue: encryption_spec_key_name}, '"}',
        ', "labels": ', {inputValue: labels},
        '}'
      ],
      --project, {inputValue: project},
      --location, {inputValue: location},
      --gcp_resources, {outputPath: gcp_resources},
      --executor_input, "{{$}}",
      {if: {
        cond: {isPresent: parent_model},
        then: ["--parent_model_name", "{{$.inputs.artifacts['parent_model'].metadata['resourceName']}}",]
      }},
    ]
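The `concat` block above stitches the `--payload` JSON together by plain string concatenation of the resolved input values. An illustrative Python sketch, with hypothetical input values, of the string it produces:

```python
# Hypothetical input values; JsonObject inputs such as `labels` arrive serialized.
display_name = 'my-forecasting-model'
description = 'demo upload'
encryption_spec_key_name = ''            # component default: empty string
labels = '{"team": "forecasting"}'

# Mirrors the concat order used in the component args above.
payload = ''.join([
    '{',
    '"display_name": "', display_name, '"',
    ', "description": "', description, '"',
    ', "encryption_spec": {"kms_key_name":"', encryption_spec_key_name, '"}',
    ', "labels": ', labels,
    '}',
])
print(payload)
# Output (wrapped for readability):
# {"display_name": "my-forecasting-model", "description": "demo upload",
#  "encryption_spec": {"kms_key_name":""}, "labels": {"team": "forecasting"}}
```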
