name: model_evaluation_forecasting
description: |
  Computes a google.ForecastingMetrics Artifact containing evaluation metrics from a model's prediction results.
  Creates a Dataflow job with Apache Beam and TFMA to compute evaluation metrics.
  Supports point forecasting and quantile forecasting for tabular data.
  Args:
    project (str):
      Project to run the evaluation container in.
    location (Optional[str]):
      Location for running the evaluation.
      If not set, defaults to `us-central1`.
    root_dir (system.Artifact):
      An artifact whose URI points to the GCS directory for keeping staging files.
      A random subdirectory will be created under the directory to keep job info for resuming
      the job in case of failure.
    predictions_format (Optional[str]):
      The file format for the batch prediction results. `jsonl` is currently the only allowed
      format.
      If not set, defaults to `jsonl`.
    predictions_gcs_source (Optional[system.Artifact]):
      An artifact with its URI pointing to a GCS directory with prediction or explanation
      files to be used for this evaluation.
      For prediction results, the files should be named "prediction.results-*".
      For explanation results, the files should be named "explanation.results-*".
    predictions_bigquery_source (Optional[google.BQTable]):
      BigQuery table with prediction or explanation data to be used for this evaluation.
      For prediction results, the table column should be named "predicted_*".
    ground_truth_format (Optional[str]):
      Required for custom tabular and non-tabular data.
      The file format for the ground truth files. `jsonl` is currently the only allowed format.
      If not set, defaults to `jsonl`.
    ground_truth_gcs_source (Optional[Sequence[str]]):
      Required for custom tabular and non-tabular data.
      The GCS URIs representing where the ground truth is located.
      Used to provide ground truth for each prediction instance when it is not part of the
      batch prediction job's prediction instances.
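      For example, a single JSONL pattern such as `["gs://example-bucket/ground_truth/*.jsonl"]`
      (an illustrative bucket and path, not values from this spec).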
    ground_truth_bigquery_source (Optional[str]):
      Required for custom tabular data.
      The BigQuery table URI representing where the ground truth is located.
      Used to provide ground truth for each prediction instance when it is not part of the
      batch prediction job's prediction instances.
    target_field_name (str):
      The full name path of the target field in the predictions file.
      Formatted to be able to find nested columns, delimited by `.`.
      Alternatively referred to as the ground truth (or ground_truth_column) field.
    model (Optional[google.VertexModel]):
      The Model used for the predictions job.
      Must share the same ancestor Location.
    prediction_score_column (Optional[str]):
      Optional. The column name of the field containing batch prediction scores.
      Formatted to be able to find nested columns, delimited by `.`.
      If not set, defaults to `prediction.value` for a `point` forecasting_type and
      `prediction.quantile_predictions` for a `quantile` forecasting_type.
    forecasting_type (Optional[str]):
      Optional. If the problem_type is `forecasting`, the forecasting type being addressed
      by this regression evaluation run. `point` and `quantile` are the supported types.
      If not set, defaults to `point`.
    forecasting_quantiles (Optional[Sequence[Float]]):
      Required for a `quantile` forecasting_type.
      The list of quantiles in the same order as they appear in the quantile prediction score
      column.
      If one of the quantiles is set to `0.5`, point evaluation will be set on that index.
    example_weight_column (Optional[str]):
      Optional. The column name of the field containing example weights.
    point_evaluation_quantile (Optional[Float]):
      Required for a `quantile` forecasting_type.
      A quantile in the list of forecasting_quantiles that will be used for point evaluation
      metrics.
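      For example, with forecasting_quantiles `[0.1, 0.5, 0.9]` and point_evaluation_quantile
      `0.5`, point metrics are computed from the second entry of the quantile prediction score
      column (illustrative quantiles; only `0.5` is a default in this spec).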
    dataflow_service_account (Optional[str]):
      Optional. Service account to run the Dataflow job.
      If not set, Dataflow will use the default worker service account.
      For more details, see https://cloud.google.com/dataflow/docs/concepts/security-and-permissions#default_worker_service_account
    dataflow_disk_size (Optional[int]):
      Optional. The disk size (in GB) of the machine executing the evaluation run.
      If not set, defaults to `50`.
    dataflow_machine_type (Optional[str]):
      Optional. The machine type executing the evaluation run.
      If not set, defaults to `n1-standard-4`.
    dataflow_workers_num (Optional[int]):
      Optional. The number of workers executing the evaluation run.
      If not set, defaults to `1`.
    dataflow_max_workers_num (Optional[int]):
      Optional. The max number of workers executing the evaluation run.
      If not set, defaults to `5`.
    dataflow_subnetwork (Optional[str]):
      Dataflow's fully qualified subnetwork name; when empty, the default subnetwork will be
      used. More details:
      https://cloud.google.com/dataflow/docs/guides/specifying-networks#example_network_and_subnetwork_specifications
    dataflow_use_public_ips (Optional[bool]):
      Specifies whether Dataflow workers use public IP addresses.
    encryption_spec_key_name (Optional[str]):
      Customer-managed encryption key.
  Returns:
    evaluation_metrics (google.ForecastingMetrics):
      google.ForecastingMetrics artifact representing the forecasting evaluation metrics in GCS.
inputs:
  - { name: project, type: String }
  - { name: location, type: String, default: "us-central1" }
  - { name: root_dir, type: system.Artifact }
  - { name: predictions_format, type: String, default: "jsonl" }
  - { name: predictions_gcs_source, type: Artifact, optional: True }
  - { name: predictions_bigquery_source, type: google.BQTable, optional: True }
  - { name: ground_truth_format, type: String, default: "jsonl" }
  - { name: ground_truth_gcs_source, type: JsonArray, default: "[]" }
  - { name: ground_truth_bigquery_source, type: String, default: "" }
  - { name: target_field_name, type: String }
  - { name: model, type: google.VertexModel, optional: True }
  - { name: prediction_score_column, type: String, default: "" }
  - { name: forecasting_type, type: String, default: "point" }
  - { name: forecasting_quantiles, type: JsonArray, default: "[0.5]" }
  - { name: example_weight_column, type: String, default: "" }
  - { name: point_evaluation_quantile, type: Float, default: 0.5 }
  - { name: dataflow_service_account, type: String, default: "" }
  - { name: dataflow_disk_size, type: Integer, default: 50 }
  - { name: dataflow_machine_type, type: String, default: "n1-standard-4" }
  - { name: dataflow_workers_num, type: Integer, default: 1 }
  - { name: dataflow_max_workers_num, type: Integer, default: 5 }
  - { name: dataflow_subnetwork, type: String, default: "" }
  - { name: dataflow_use_public_ips, type: Boolean, default: "true" }
  - { name: encryption_spec_key_name, type: String, default: "" }
outputs:
  - { name: evaluation_metrics, type: google.ForecastingMetrics }
  - { name: gcp_resources, type: String }
implementation:
  container:
    image: gcr.io/ml-pipeline/model-evaluation:v0.9
    command:
      - python
      - /main.py
    args:
      - --setup_file
      - /setup.py
      - --json_mode
      - "true"
      - --project_id
      - { inputValue: project }
      - --location
      - { inputValue: location }
      - --problem_type
      - "forecasting"
      - --forecasting_type
      - { inputValue: forecasting_type }
      - --forecasting_quantiles
      - { inputValue: forecasting_quantiles }
      - --point_evaluation_quantile
      - { inputValue: point_evaluation_quantile }
      - --batch_prediction_format
      - { inputValue: predictions_format }
      - if:
          cond: { isPresent: predictions_gcs_source }
          then:
            - --batch_prediction_gcs_source
            - "{{$.inputs.artifacts['predictions_gcs_source'].uri}}"
      - if:
          cond: { isPresent: predictions_bigquery_source }
          then:
            - --batch_prediction_bigquery_source
            - "bq://{{$.inputs.artifacts['predictions_bigquery_source'].metadata['projectId']}}.{{$.inputs.artifacts['predictions_bigquery_source'].metadata['datasetId']}}.{{$.inputs.artifacts['predictions_bigquery_source'].metadata['tableId']}}"
      - if:
          cond: { isPresent: model }
          then:
            - --model_name
            - "{{$.inputs.artifacts['model'].metadata['resourceName']}}"
      - --ground_truth_format
      - { inputValue: ground_truth_format }
      - --ground_truth_gcs_source
      - { inputValue: ground_truth_gcs_source }
      - --ground_truth_bigquery_source
      - { inputValue: ground_truth_bigquery_source }
      - --root_dir
      - "{{$.inputs.artifacts['root_dir'].uri}}"
      - --target_field_name
      - "instance.{{$.inputs.parameters['target_field_name']}}"
      - --prediction_score_column
      - { inputValue: prediction_score_column }
      - --dataflow_job_prefix
      - "evaluation-{{$.pipeline_job_uuid}}-{{$.pipeline_task_uuid}}"
      - --dataflow_service_account
      - { inputValue: dataflow_service_account }
      - --dataflow_disk_size
      - { inputValue: dataflow_disk_size }
      - --dataflow_machine_type
      - { inputValue: dataflow_machine_type }
      - --dataflow_workers_num
      - { inputValue: dataflow_workers_num }
      - --dataflow_max_workers_num
      - { inputValue: dataflow_max_workers_num }
      - --dataflow_subnetwork
      - { inputValue: dataflow_subnetwork }
      - --dataflow_use_public_ips
      - { inputValue: dataflow_use_public_ips }
      - --kms_key_name
      - { inputValue: encryption_spec_key_name }
      - --output_metrics_gcs_path
      - { outputUri: evaluation_metrics }
      - --gcp_resources
      - { outputPath: gcp_resources }
      - --executor_input
      - "{{$}}"