feat: Introduce cache_key to sdk #11466

Open
wants to merge 1 commit into master
23 changes: 22 additions & 1 deletion sdk/python/kfp/client/client.py
@@ -686,6 +686,7 @@ def run_pipeline(
version_id: Optional[str] = None,
pipeline_root: Optional[str] = None,
enable_caching: Optional[bool] = None,
cache_key: Optional[str] = '',

Member:

nit: default to None, and set the proto field only when it's not None
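
To make the nit concrete, a rough sketch of the suggested shape (a sketch only: the None guard is the reviewer's proposal, not what the PR currently does; _override_caching_options is the helper changed at the bottom of this file):

from typing import Optional

from kfp.pipeline_spec import pipeline_spec_pb2

def _override_caching_options(
    pipeline_spec: pipeline_spec_pb2.PipelineSpec,
    enable_caching: bool,
    cache_key: Optional[str] = None,
) -> None:
    """Overrides caching options, writing cache_key only when provided."""
    for task_spec in pipeline_spec.root.dag.tasks.values():
        task_spec.caching_options.enable_cache = enable_caching
        # Skipping the write when no key was passed avoids clobbering a
        # cache key that was baked into the spec at compile time.
        if cache_key is not None:
            task_spec.caching_options.cache_key = cache_key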

service_account: Optional[str] = None,
) -> kfp_server_api.V2beta1Run:
"""Runs a specified pipeline.
@@ -709,6 +710,8 @@ def run_pipeline(
is ``True`` for all tasks by default. If set, the
setting applies to all tasks in the pipeline (overrides the
compile time settings).
cache_key (optional): Customized cache key for this run.
If set, it is used as the cache key for all tasks in the pipeline.

Contributor:

Could you elaborate on why we would need a cache key? This doc string isn't providing a clear enough idea.

I would suggest adding some description to the PR, and also linking the related GitHub issue.

Collaborator (Author):

Thanks for reviewing!

This is for Issue #11328
Doc: https://docs.google.com/document/d/1oNgYyFYondaVSFf9Pd3Q9uVzaqBr5wrOMHKgX9MMa78/edit?tab=t.0

(Also updated this information in description)
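
For readers following the thread, a minimal client-side usage sketch of what this PR enables (host and package path are placeholders):

import kfp

client = kfp.Client(host='http://localhost:8080')  # placeholder endpoint

# cache_key is stamped onto every task's caching_options for this run;
# per _create_job_config below, it only takes effect when enable_caching
# is not None.
run = client.create_run_from_pipeline_package(
    'pipeline.yaml',                   # placeholder compiled package
    arguments={'text': 'hi there'},
    enable_caching=True,
    cache_key='experiment-v1',
)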

service_account: Specifies which Kubernetes service
account to use for this run.

@@ -721,6 +724,7 @@
pipeline_id=pipeline_id,
version_id=version_id,
enable_caching=enable_caching,
cache_key=cache_key,
pipeline_root=pipeline_root,
)

@@ -806,6 +810,7 @@ def create_recurring_run(
enabled: bool = True,
pipeline_root: Optional[str] = None,
enable_caching: Optional[bool] = None,
cache_key: Optional[str] = '',
service_account: Optional[str] = None,
) -> kfp_server_api.V2beta1RecurringRun:
"""Creates a recurring run.
@@ -850,6 +855,8 @@ def create_recurring_run(
different caching options for individual tasks. If set, the
setting applies to all tasks in the pipeline (overrides the
compile time settings).
cache_key (optional): Customized cache key for this recurring run.
If set, it is used as the cache key for all tasks in the pipeline.
service_account: Specifies which Kubernetes service
account this recurring run uses.
Returns:
@@ -862,6 +869,7 @@
pipeline_id=pipeline_id,
version_id=version_id,
enable_caching=enable_caching,
cache_key=cache_key,
pipeline_root=pipeline_root,
)

@@ -908,6 +916,7 @@ def _create_job_config(
pipeline_id: Optional[str],
version_id: Optional[str],
enable_caching: Optional[bool],
cache_key: Optional[str],
pipeline_root: Optional[str],
) -> _JobConfig:
"""Creates a JobConfig with spec and resource_references.
@@ -928,6 +937,8 @@
different caching options for individual tasks. If set, the
setting applies to all tasks in the pipeline (overrides the
compile time settings).
cache_key (optional): Customized cache key for the run.
If set, it is used as the cache key for all tasks in the pipeline.
pipeline_root: Root path of the pipeline outputs.

Returns:
@@ -956,7 +967,7 @@ def _create_job_config(
# settings.
if enable_caching is not None:
_override_caching_options(pipeline_doc.pipeline_spec,
enable_caching)
enable_caching, cache_key)
pipeline_spec = pipeline_doc.to_dict()

pipeline_version_reference = None
@@ -983,6 +994,7 @@ def create_run_from_pipeline_func(
namespace: Optional[str] = None,
pipeline_root: Optional[str] = None,
enable_caching: Optional[bool] = None,
cache_key: Optional[str] = '',
service_account: Optional[str] = None,
experiment_id: Optional[str] = None,
) -> RunPipelineResult:
@@ -1004,6 +1016,8 @@
different caching options for individual tasks. If set, the
setting applies to all tasks in the pipeline (overrides the
compile time settings).
cache_key (optional): Customized cache key for this run.
If set, it is used as the cache key for all tasks in the pipeline.
service_account: Specifies which Kubernetes service
account to use for this run.
experiment_id: ID of the experiment to add the run to. You cannot specify both experiment_id and experiment_name.
@@ -1032,6 +1046,7 @@
namespace=namespace,
pipeline_root=pipeline_root,
enable_caching=enable_caching,
cache_key=cache_key,
service_account=service_account,
)

@@ -1044,6 +1059,7 @@ def create_run_from_pipeline_package(
namespace: Optional[str] = None,
pipeline_root: Optional[str] = None,
enable_caching: Optional[bool] = None,
cache_key: Optional[str] = '',
service_account: Optional[str] = None,
experiment_id: Optional[str] = None,
) -> RunPipelineResult:
@@ -1065,6 +1081,8 @@
different caching options for individual tasks. If set, the
setting applies to all tasks in the pipeline (overrides the
compile time settings).
cache_key (optional): Customized cache key for this run.
If set, it is used as the cache key for all tasks in the pipeline.
service_account: Specifies which Kubernetes service
account to use for this run.
experiment_id: ID of the experiment to add the run to. You cannot specify both experiment_id and experiment_name.
@@ -1105,6 +1123,7 @@ def create_run_from_pipeline_package(
params=arguments,
pipeline_root=pipeline_root,
enable_caching=enable_caching,
cache_key=cache_key,
service_account=service_account,
)
return RunPipelineResult(self, run_info)
@@ -1681,6 +1700,7 @@ def _safe_load_yaml(stream: TextIO) -> _PipelineDoc:
def _override_caching_options(
pipeline_spec: pipeline_spec_pb2.PipelineSpec,
enable_caching: bool,
cache_key: str = '',
) -> None:
"""Overrides caching options.

@@ -1690,3 +1710,4 @@ def _override_caching_options(
"""
for _, task_spec in pipeline_spec.root.dag.tasks.items():
task_spec.caching_options.enable_cache = enable_caching
task_spec.caching_options.cache_key = cache_key
9 changes: 8 additions & 1 deletion sdk/python/kfp/client/client_test.py
@@ -88,12 +88,19 @@ def pipeline_with_two_component(text: str = 'hi there'):
pipeline_obj = yaml.safe_load(f)
pipeline_spec = json_format.ParseDict(
pipeline_obj, pipeline_spec_pb2.PipelineSpec())
client._override_caching_options(pipeline_spec, True)
client._override_caching_options(
pipeline_spec, True, cache_key='OVERRIDE_KEY')
pipeline_obj = json_format.MessageToDict(pipeline_spec)
self.assertTrue(pipeline_obj['root']['dag']['tasks']
['hello-word']['cachingOptions']['enableCache'])
self.assertEqual(
pipeline_obj['root']['dag']['tasks']['hello-word']
['cachingOptions']['cacheKey'], 'OVERRIDE_KEY')
self.assertTrue(pipeline_obj['root']['dag']['tasks']['to-lower']
['cachingOptions']['enableCache'])
self.assertEqual(
pipeline_obj['root']['dag']['tasks']['to-lower']
['cachingOptions']['cacheKey'], 'OVERRIDE_KEY')


class TestExtractPipelineYAML(parameterized.TestCase):
Expand Down
26 changes: 26 additions & 0 deletions sdk/python/kfp/compiler/compiler_test.py
@@ -1001,6 +1001,32 @@ def my_pipeline():

self.assertTrue(caching_options['enableCache'])

def test_compile_pipeline_with_cache_key(self):
"""Test pipeline compilation with cache key."""

@dsl.component
def my_component():
pass

@dsl.pipeline(name='tiny-pipeline')
def my_pipeline():
my_task = my_component()
my_task.set_caching_options(True, cache_key='MY_KEY')

with tempfile.TemporaryDirectory() as tempdir:
output_yaml = os.path.join(tempdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=output_yaml)

with open(output_yaml, 'r') as f:
pipeline_spec = yaml.safe_load(f)

task_spec = pipeline_spec['root']['dag']['tasks']['my-component']
caching_options = task_spec['cachingOptions']

self.assertTrue(caching_options['enableCache'])
self.assertEqual(caching_options['cacheKey'], 'MY_KEY')

def test_compile_pipeline_with_caching_disabled(self):
"""Test pipeline compilation with caching disabled."""

1 change: 1 addition & 0 deletions sdk/python/kfp/compiler/pipeline_spec_builder.py
@@ -124,6 +124,7 @@ def build_task_spec_for_task(
utils.sanitize_component_name(task.name))
pipeline_task_spec.caching_options.enable_cache = (
task._task_spec.enable_caching)
pipeline_task_spec.caching_options.cache_key = (task._task_spec.cache_key)

if task._task_spec.retry_policy is not None:
pipeline_task_spec.retry_policy.CopyFrom(
10 changes: 8 additions & 2 deletions sdk/python/kfp/dsl/pipeline_task.py
@@ -99,6 +99,7 @@ def __init__(
args: Dict[str, Any],
execute_locally: bool = False,
execution_caching_default: bool = True,
execution_cache_key: str = '',
) -> None:
"""Initilizes a PipelineTask instance."""
# import within __init__ to avoid circular import
@@ -131,7 +132,8 @@ def __init__(
inputs=dict(args.items()),
dependent_tasks=[],
component_ref=component_spec.name,
enable_caching=execution_caching_default)
enable_caching=execution_caching_default,
cache_key=execution_cache_key)
self._run_after: List[str] = []

self.importer_spec = None
@@ -301,16 +303,20 @@ def _extract_container_spec_and_convert_placeholders(
return container_spec

@block_if_final()
def set_caching_options(self, enable_caching: bool) -> 'PipelineTask':
def set_caching_options(self,
enable_caching: bool,
cache_key: str = '') -> 'PipelineTask':
"""Sets caching options for the task.

Args:
enable_caching: Whether to enable caching.
cache_key: Customized cache key for this task.

Returns:
Self return to allow chained setting calls.
"""
self._task_spec.enable_caching = enable_caching
self._task_spec.cache_key = cache_key
return self

def _ensure_container_spec_exists(self) -> None:
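
The DSL-side counterpart: roughly how a pipeline author would use the extended setter (component and key names are illustrative; this mirrors the compiler test above):

from kfp import dsl

@dsl.component
def preprocess(text: str) -> str:
    return text.lower()

@dsl.pipeline(name='cache-key-demo')
def demo_pipeline(text: str = 'hi there'):
    task = preprocess(text=text)
    # The setter still returns the task, so calls can be chained.
    task.set_caching_options(True, cache_key='preprocess-v1')
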
3 changes: 3 additions & 0 deletions sdk/python/kfp/dsl/structures.py
@@ -409,6 +409,8 @@ class TaskSpec:
from the [items][] collection.
enable_caching (optional): whether or not to enable caching for the task.
Default is True.
cache_key (optional): Customized cache key for this task.
Default is an empty string.
display_name (optional): the display name of the task. If not specified,
the task name will be used as the display name.
"""
@@ -421,6 +423,7 @@
iterator_items: Optional[Any] = None
iterator_item_input: Optional[str] = None
enable_caching: bool = True
cache_key: str = ''
display_name: Optional[str] = None
retry_policy: Optional[RetryPolicy] = None
