Skip to content

Commit 2a72f12

Browse files
author
Ubuntu
committed
feat: Introduce cache_key to sdk
Signed-off-by: Ze Mao <zemao@google.com>
1 parent 0eb67e1 commit 2a72f12

File tree

6 files changed

+68
-4
lines changed

6 files changed

+68
-4
lines changed

sdk/python/kfp/client/client.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,7 @@ def run_pipeline(
686686
version_id: Optional[str] = None,
687687
pipeline_root: Optional[str] = None,
688688
enable_caching: Optional[bool] = None,
689+
cache_key: Optional[str] = '',
689690
service_account: Optional[str] = None,
690691
) -> kfp_server_api.V2beta1Run:
691692
"""Runs a specified pipeline.
@@ -709,6 +710,8 @@ def run_pipeline(
709710
is ``True`` for all tasks by default. If set, the
710711
setting applies to all tasks in the pipeline (overrides the
711712
compile time settings).
713+
cache_key (optional): Customized cache key for this task.
714+
If set, the cache_key will be used as the key for the task's cache.
712715
service_account: Specifies which Kubernetes service
713716
account to use for this run.
714717
@@ -721,6 +724,7 @@ def run_pipeline(
721724
pipeline_id=pipeline_id,
722725
version_id=version_id,
723726
enable_caching=enable_caching,
727+
cache_key=cache_key,
724728
pipeline_root=pipeline_root,
725729
)
726730

@@ -806,6 +810,7 @@ def create_recurring_run(
806810
enabled: bool = True,
807811
pipeline_root: Optional[str] = None,
808812
enable_caching: Optional[bool] = None,
813+
cache_key: Optional[str] = '',
809814
service_account: Optional[str] = None,
810815
) -> kfp_server_api.V2beta1RecurringRun:
811816
"""Creates a recurring run.
@@ -850,6 +855,8 @@ def create_recurring_run(
850855
different caching options for individual tasks. If set, the
851856
setting applies to all tasks in the pipeline (overrides the
852857
compile time settings).
858+
cache_key (optional): Customized cache key for this task.
859+
If set, the cache_key will be used as the key for the task's cache.
853860
service_account: Specifies which Kubernetes service
854861
account this recurring run uses.
855862
Returns:
@@ -862,6 +869,7 @@ def create_recurring_run(
862869
pipeline_id=pipeline_id,
863870
version_id=version_id,
864871
enable_caching=enable_caching,
872+
cache_key=cache_key,
865873
pipeline_root=pipeline_root,
866874
)
867875

@@ -908,6 +916,7 @@ def _create_job_config(
908916
pipeline_id: Optional[str],
909917
version_id: Optional[str],
910918
enable_caching: Optional[bool],
919+
cache_key: Optional[str],
911920
pipeline_root: Optional[str],
912921
) -> _JobConfig:
913922
"""Creates a JobConfig with spec and resource_references.
@@ -928,6 +937,8 @@ def _create_job_config(
928937
different caching options for individual tasks. If set, the
929938
setting applies to all tasks in the pipeline (overrides the
930939
compile time settings).
940+
cache_key (optional): Customized cache key for this task.
941+
If set, the cache_key will be used as the key for the task's cache.
931942
pipeline_root: Root path of the pipeline outputs.
932943
933944
Returns:
@@ -956,7 +967,7 @@ def _create_job_config(
956967
# settings.
957968
if enable_caching is not None:
958969
_override_caching_options(pipeline_doc.pipeline_spec,
959-
enable_caching)
970+
enable_caching, cache_key)
960971
pipeline_spec = pipeline_doc.to_dict()
961972

962973
pipeline_version_reference = None
@@ -983,6 +994,7 @@ def create_run_from_pipeline_func(
983994
namespace: Optional[str] = None,
984995
pipeline_root: Optional[str] = None,
985996
enable_caching: Optional[bool] = None,
997+
cache_key: Optional[str] = '',
986998
service_account: Optional[str] = None,
987999
experiment_id: Optional[str] = None,
9881000
) -> RunPipelineResult:
@@ -1004,6 +1016,8 @@ def create_run_from_pipeline_func(
10041016
different caching options for individual tasks. If set, the
10051017
setting applies to all tasks in the pipeline (overrides the
10061018
compile time settings).
1019+
cache_key (optional): Customized cache key for this task.
1020+
If set, the cache_key will be used as the key for the task's cache.
10071021
service_account: Specifies which Kubernetes service
10081022
account to use for this run.
10091023
experiment_id: ID of the experiment to add the run to. You cannot specify both experiment_id and experiment_name.
@@ -1032,6 +1046,7 @@ def create_run_from_pipeline_func(
10321046
namespace=namespace,
10331047
pipeline_root=pipeline_root,
10341048
enable_caching=enable_caching,
1049+
cache_key=cache_key,
10351050
service_account=service_account,
10361051
)
10371052

@@ -1044,6 +1059,7 @@ def create_run_from_pipeline_package(
10441059
namespace: Optional[str] = None,
10451060
pipeline_root: Optional[str] = None,
10461061
enable_caching: Optional[bool] = None,
1062+
cache_key: Optional[str] = '',
10471063
service_account: Optional[str] = None,
10481064
experiment_id: Optional[str] = None,
10491065
) -> RunPipelineResult:
@@ -1065,6 +1081,8 @@ def create_run_from_pipeline_package(
10651081
different caching options for individual tasks. If set, the
10661082
setting applies to all tasks in the pipeline (overrides the
10671083
compile time settings).
1084+
cache_key (optional): Customized cache key for this task.
1085+
If set, the cache_key will be used as the key for the task's cache.
10681086
service_account: Specifies which Kubernetes service
10691087
account to use for this run.
10701088
experiment_id: ID of the experiment to add the run to. You cannot specify both experiment_id and experiment_name.
@@ -1105,6 +1123,7 @@ def create_run_from_pipeline_package(
11051123
params=arguments,
11061124
pipeline_root=pipeline_root,
11071125
enable_caching=enable_caching,
1126+
cache_key=cache_key,
11081127
service_account=service_account,
11091128
)
11101129
return RunPipelineResult(self, run_info)
@@ -1681,6 +1700,7 @@ def _safe_load_yaml(stream: TextIO) -> _PipelineDoc:
16811700
def _override_caching_options(
16821701
pipeline_spec: pipeline_spec_pb2.PipelineSpec,
16831702
enable_caching: bool,
1703+
cache_key: str = '',
16841704
) -> None:
16851705
"""Overrides caching options.
16861706
@@ -1690,3 +1710,4 @@ def _override_caching_options(
16901710
"""
16911711
for _, task_spec in pipeline_spec.root.dag.tasks.items():
16921712
task_spec.caching_options.enable_cache = enable_caching
1713+
task_spec.caching_options.cache_key = cache_key

sdk/python/kfp/client/client_test.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,19 @@ def pipeline_with_two_component(text: str = 'hi there'):
8888
pipeline_obj = yaml.safe_load(f)
8989
pipeline_spec = json_format.ParseDict(
9090
pipeline_obj, pipeline_spec_pb2.PipelineSpec())
91-
client._override_caching_options(pipeline_spec, True)
91+
client._override_caching_options(
92+
pipeline_spec, True, cache_key='OVERRIDE_KEY')
9293
pipeline_obj = json_format.MessageToDict(pipeline_spec)
9394
self.assertTrue(pipeline_obj['root']['dag']['tasks']
9495
['hello-word']['cachingOptions']['enableCache'])
96+
self.assertEqual(
97+
pipeline_obj['root']['dag']['tasks']['hello-word']
98+
['cachingOptions']['cacheKey'], 'OVERRIDE_KEY')
9599
self.assertTrue(pipeline_obj['root']['dag']['tasks']['to-lower']
96100
['cachingOptions']['enableCache'])
101+
self.assertEqual(
102+
pipeline_obj['root']['dag']['tasks']['to-lower']
103+
['cachingOptions']['cacheKey'], 'OVERRIDE_KEY')
97104

98105

99106
class TestExtractPipelineYAML(parameterized.TestCase):

sdk/python/kfp/compiler/compiler_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,6 +1001,32 @@ def my_pipeline():
10011001

10021002
self.assertTrue(caching_options['enableCache'])
10031003

1004+
def test_compile_pipeline_with_cache_key(self):
1005+
"""Test pipeline compilation with cache key."""
1006+
1007+
@dsl.component
1008+
def my_component():
1009+
pass
1010+
1011+
@dsl.pipeline(name='tiny-pipeline')
1012+
def my_pipeline():
1013+
my_task = my_component()
1014+
my_task.set_caching_options(True, cache_key='MY_KEY')
1015+
1016+
with tempfile.TemporaryDirectory() as tempdir:
1017+
output_yaml = os.path.join(tempdir, 'pipeline.yaml')
1018+
compiler.Compiler().compile(
1019+
pipeline_func=my_pipeline, package_path=output_yaml)
1020+
1021+
with open(output_yaml, 'r') as f:
1022+
pipeline_spec = yaml.safe_load(f)
1023+
1024+
task_spec = pipeline_spec['root']['dag']['tasks']['my-component']
1025+
caching_options = task_spec['cachingOptions']
1026+
1027+
self.assertTrue(caching_options['enableCache'])
1028+
self.assertEqual(caching_options['cacheKey'], 'MY_KEY')
1029+
10041030
def test_compile_pipeline_with_caching_disabled(self):
10051031
"""Test pipeline compilation with caching disabled."""
10061032

sdk/python/kfp/compiler/pipeline_spec_builder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def build_task_spec_for_task(
124124
utils.sanitize_component_name(task.name))
125125
pipeline_task_spec.caching_options.enable_cache = (
126126
task._task_spec.enable_caching)
127+
pipeline_task_spec.caching_options.cache_key = (task._task_spec.cache_key)
127128

128129
if task._task_spec.retry_policy is not None:
129130
pipeline_task_spec.retry_policy.CopyFrom(

sdk/python/kfp/dsl/pipeline_task.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def __init__(
9999
args: Dict[str, Any],
100100
execute_locally: bool = False,
101101
execution_caching_default: bool = True,
102+
execution_cache_key: str = '',
102103
) -> None:
103104
"""Initilizes a PipelineTask instance."""
104105
# import within __init__ to avoid circular import
@@ -131,7 +132,8 @@ def __init__(
131132
inputs=dict(args.items()),
132133
dependent_tasks=[],
133134
component_ref=component_spec.name,
134-
enable_caching=execution_caching_default)
135+
enable_caching=execution_caching_default,
136+
cache_key=execution_cache_key)
135137
self._run_after: List[str] = []
136138

137139
self.importer_spec = None
@@ -301,16 +303,20 @@ def _extract_container_spec_and_convert_placeholders(
301303
return container_spec
302304

303305
@block_if_final()
304-
def set_caching_options(self, enable_caching: bool) -> 'PipelineTask':
306+
def set_caching_options(self,
307+
enable_caching: bool,
308+
cache_key: str = '') -> 'PipelineTask':
305309
"""Sets caching options for the task.
306310
307311
Args:
308312
enable_caching: Whether to enable caching.
313+
cache_key: Customized cache key for this task.
309314
310315
Returns:
311316
Self return to allow chained setting calls.
312317
"""
313318
self._task_spec.enable_caching = enable_caching
319+
self._task_spec.cache_key = cache_key
314320
return self
315321

316322
def _ensure_container_spec_exists(self) -> None:

sdk/python/kfp/dsl/structures.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,8 @@ class TaskSpec:
409409
from the [items][] collection.
410410
enable_caching (optional): whether or not to enable caching for the task.
411411
Default is True.
412+
cache_key (optional): Customized cache key for this task.
413+
Default is empty string.
412414
display_name (optional): the display name of the task. If not specified,
413415
the task name will be used as the display name.
414416
"""
@@ -421,6 +423,7 @@ class TaskSpec:
421423
iterator_items: Optional[Any] = None
422424
iterator_item_input: Optional[str] = None
423425
enable_caching: bool = True
426+
cache_key: str = ''
424427
display_name: Optional[str] = None
425428
retry_policy: Optional[RetryPolicy] = None
426429

0 commit comments

Comments
 (0)