
Commit 522b593

Author: Googler
feat(components): internal
Signed-off-by: Googler <nobody@google.com>
PiperOrigin-RevId: 658993845
1 parent 1612dac · commit 522b593

File tree: 2 files changed, +196 -0 lines changed

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
"""KFP DSL compiler with Vertex Specific Features."""

from google_cloud_pipeline_components.preview.compiler import Compiler

__all__ = [
    'Compiler',
]
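For orientation (an editorial sketch, not part of the commit): with the export above, downstream code would obtain the Vertex-aware compiler from the preview namespace, for example:

    from google_cloud_pipeline_components.preview.compiler import Compiler

    compiler_instance = Compiler()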
Lines changed: 189 additions & 0 deletions
@@ -0,0 +1,189 @@
"""KFP DSL compiler with Vertex Specific Features.

This compiler is intended to compile PipelineSpec with Vertex specific
features.

To ensure full compatibility with Vertex specific functionalities,
Google first party pipelines should utilize this version of the compiler.
"""

import os
from os import path
from typing import Any, Dict, Optional

from absl import logging
from google.protobuf import json_format
from google_cloud_pipeline_components.proto import template_metadata_pb2
from kfp import compiler as kfp_compiler
from kfp.dsl import base_component
from kfp.pipeline_spec import pipeline_spec_pb2
import yaml


class Compiler:
  """Compiles pipelines composed using the KFP SDK DSL to a YAML pipeline definition.

  The pipeline definition is `PipelineSpec IR
  <https://github.com/kubeflow/pipelines/blob/2060e38c5591806d657d85b53eed2eef2e5de2ae/api/v2alpha1/pipeline_spec.proto#L50>`_,
  the protobuf message that defines a pipeline.

  Example:
    ::

      @dsl.pipeline(
          name='name',
      )
      def my_pipeline(a: int, b: str = 'default value'):
          ...

      compiler.Compiler().compile(
          pipeline_func=my_pipeline,
          package_path='path/to/pipeline.yaml',
          pipeline_parameters={'a': 1},
      )
  """

  def merge_template_metadata_to_pipeline_spec_proto(
      self,
      template_metadata: Optional[template_metadata_pb2.TemplateMetadata],
      pipeline_spec_proto: pipeline_spec_pb2.PipelineSpec,
  ) -> pipeline_spec_pb2.PipelineSpec:
    """Merges TemplateMetadata into a PipelineSpec for execution on Google Cloud.

    This function prepares a PipelineSpec for execution on Google Cloud by
    incorporating TemplateMetadata into the platform-specific configuration.
    The metadata is converted to JSON and embedded within the 'google_cloud'
    platform configuration.

    Args:
      template_metadata: A TemplateMetadata object containing component
        metadata.
      pipeline_spec_proto: A PipelineSpec protobuf object representing the
        pipeline specification.

    Returns:
      A modified PipelineSpec protobuf object with the TemplateMetadata merged
      into the 'google_cloud' PlatformSpec configuration, or the original
      PipelineSpec proto unchanged if template_metadata is None.
    """
    if template_metadata is None:
      return pipeline_spec_proto
    # Serialize the metadata to JSON and embed it as the 'google_cloud'
    # platform config on the pipeline root.
    template_metadata_json = json_format.MessageToJson(template_metadata)
    platform_spec_proto = pipeline_spec_pb2.PlatformSpec()
    platform_spec_proto.platform = "google_cloud"
    json_format.Parse(template_metadata_json, platform_spec_proto.config)
    pipeline_spec_proto.root.platform_specs.append(platform_spec_proto)
    return pipeline_spec_proto

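  # Illustration only (editorial, not part of the original commit): after the
  # merge above, the compiled YAML is expected to carry the metadata roughly
  # as sketched below, assuming proto3 JSON field naming; the config contents
  # mirror whatever the TemplateMetadata message defines.
  #
  #   root:
  #     platformSpecs:
  #       - platform: google_cloud
  #         config:
  #           <TemplateMetadata rendered as JSON>
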
  def parse_pipeline_spec_yaml(
      self,
      pipeline_spec_yaml_file: str,
  ) -> pipeline_spec_pb2.PipelineSpec:
    """Parses a pipeline spec YAML file into a PipelineSpec proto.

    Args:
      pipeline_spec_yaml_file: Path to the pipeline spec YAML file.

    Returns:
      The parsed PipelineSpec proto.

    Raises:
      ValueError: When the PipelineSpec is invalid.
    """
    with open(pipeline_spec_yaml_file, "r") as f:
      pipeline_spec_yaml = f.read()
    pipeline_spec_dict = yaml.safe_load(pipeline_spec_yaml)
    pipeline_spec_proto = pipeline_spec_pb2.PipelineSpec()
    try:
      json_format.ParseDict(pipeline_spec_dict, pipeline_spec_proto)
    except json_format.ParseError as e:
      raise ValueError(
          "Failed to parse %s. Please check that it is a valid YAML file"
          " representing a PipelineSpec proto." % pipeline_spec_yaml_file
      ) from e
    if not pipeline_spec_proto.HasField("pipeline_info"):
      raise ValueError(
          "PipelineInfo field not found in the pipeline spec YAML file %s."
          % pipeline_spec_yaml_file
      )
    if not pipeline_spec_proto.pipeline_info.display_name:
      logging.warning(
          "PipelineInfo.displayName field is empty in pipeline spec YAML"
          " file %s.",
          pipeline_spec_yaml_file,
      )
    if not pipeline_spec_proto.pipeline_info.description:
      logging.warning(
          "PipelineInfo.description field is empty in pipeline spec YAML"
          " file %s.",
          pipeline_spec_yaml_file,
      )
    return pipeline_spec_proto

  def compile(
      self,
      pipeline_func: base_component.BaseComponent,
      package_path: str,
      pipeline_name: Optional[str] = None,
      pipeline_parameters: Optional[Dict[str, Any]] = None,
      type_check: bool = True,
      include_vertex_specific_features: bool = True,
  ) -> None:
    """Compiles the pipeline or component function into IR YAML.

    By default, this compiler also compiles any Vertex specific features.

    Args:
      pipeline_func: Pipeline function constructed with the ``@dsl.pipeline``
        decorator, or component constructed with the ``@dsl.component``
        decorator.
      package_path: Output YAML file path. For example,
        ``'~/my_pipeline.yaml'`` or ``'~/my_component.yaml'``.
      pipeline_name: Name of the pipeline.
      pipeline_parameters: Map of parameter names to argument values.
      type_check: Whether to enable type checking of component interfaces
        during compilation.
      include_vertex_specific_features: Whether to enable compiling Vertex
        specific features.
    """
    if not include_vertex_specific_features:
      kfp_compiler.Compiler().compile(
          pipeline_func=pipeline_func,
          package_path=package_path,
          pipeline_name=pipeline_name,
          pipeline_parameters=pipeline_parameters,
          type_check=type_check,
      )
      return

    # Compile to a temporary file first, then post-process the resulting
    # PipelineSpec with the Vertex specific metadata.
    local_temp_output_dir = path.join(path.dirname(package_path), "tmp.yaml")

    kfp_compiler.Compiler().compile(
        pipeline_func=pipeline_func,
        package_path=local_temp_output_dir,
        pipeline_name=pipeline_name,
        pipeline_parameters=pipeline_parameters,
        type_check=type_check,
    )

    original_pipeline_spec = self.parse_pipeline_spec_yaml(
        local_temp_output_dir
    )
    template_metadata = getattr(pipeline_func, "template_metadata", None)
    updated_pipeline_spec = self.merge_template_metadata_to_pipeline_spec_proto(
        template_metadata, original_pipeline_spec
    )
    updated_pipeline_spec_dict = json_format.MessageToDict(
        updated_pipeline_spec
    )

    with open(package_path, "w") as f:
      yaml.dump(updated_pipeline_spec_dict, f)

    os.remove(local_temp_output_dir)
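
For illustration (an editorial sketch, not part of the commit): a minimal end-to-end use of the new compile() entry point, assuming a pipeline defined with the KFP DSL as in the class docstring; the pipeline name and output paths here are placeholders.

    from kfp import dsl
    from google_cloud_pipeline_components.preview.compiler import Compiler


    @dsl.pipeline(name='my-pipeline')
    def my_pipeline(a: int, b: str = 'default value'):
      ...

    # Compiles with Vertex specific features merged into the PipelineSpec.
    Compiler().compile(
        pipeline_func=my_pipeline,
        package_path='my_pipeline.yaml',
        pipeline_parameters={'a': 1},
    )

    # Falls back to the stock KFP compiler output.
    Compiler().compile(
        pipeline_func=my_pipeline,
        package_path='my_pipeline_plain.yaml',
        include_vertex_specific_features=False,
    )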
